{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/ljw/code/nlp/dialogue/transformer\r\n"
     ]
    }
   ],
   "source": [
    "# Show the kernel's working directory; MODEL_ROOT is a relative path\n",
    "# (the checkpoint error in this notebook reports '../../model/dialog/...'), so it resolves against it.\n",
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "from config import *\n",
    "from transformer import TransformerEncoder, TransformerDecoder, EncoderDecoder\n",
    "from attention import sequence_mask\n",
    "from data import loaddata, MyDataset, Lang"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class Timer:\n",
    "    \"\"\"Record multiple running times.\n",
    "\n",
    "    The timer starts on construction. Each stop() appends the seconds elapsed\n",
    "    since the most recent start() to self.times and returns that value.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.times = []\n",
    "        self.start()\n",
    "\n",
    "    def start(self):\n",
    "        \"\"\"Start (or restart) the timer.\"\"\"\n",
    "        self.tik = time.time()\n",
    "\n",
    "    def stop(self):\n",
    "        \"\"\"Record the seconds elapsed since start() and return them.\"\"\"\n",
    "        self.times.append(time.time() - self.tik)\n",
    "        return self.times[-1]\n",
    "\n",
    "    def avg(self):\n",
    "        \"\"\"Return the mean recorded time (ZeroDivisionError if nothing was recorded).\"\"\"\n",
    "        return sum(self.times) / len(self.times)\n",
    "\n",
    "    def sum(self):\n",
    "        \"\"\"Return the total of all recorded times.\"\"\"\n",
    "        # Inside a method, `sum` still refers to the builtin, not this method.\n",
    "        return sum(self.times)\n",
    "\n",
    "    def cumsum(self):\n",
    "        \"\"\"Return the running totals of the recorded times as a list.\"\"\"\n",
    "        return np.array(self.times).cumsum().tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "import time\n",
    "from IPython import display\n",
    "    \n",
    "def set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):\n",
    "    \"\"\"Configure a matplotlib Axes: labels, scales, limits, optional legend, grid.\"\"\"\n",
    "    # Applied in a fixed order: labels first, then scales, then limits.\n",
    "    configuration = (\n",
    "        (axes.set_xlabel, xlabel),\n",
    "        (axes.set_ylabel, ylabel),\n",
    "        (axes.set_xscale, xscale),\n",
    "        (axes.set_yscale, yscale),\n",
    "        (axes.set_xlim, xlim),\n",
    "        (axes.set_ylim, ylim),\n",
    "    )\n",
    "    for apply_setting, value in configuration:\n",
    "        apply_setting(value)\n",
    "    if legend:\n",
    "        axes.legend(legend)\n",
    "    axes.grid()\n",
    "\n",
    "class Animator:\n",
    "    \"\"\"Redraw one or more line series in the notebook output as points are added.\"\"\"\n",
    "\n",
    "    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,\n",
    "                 ylim=None, xscale='linear', yscale='linear',\n",
    "                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,\n",
    "                 figsize=(3.5, 2.5)):\n",
    "        legend = [] if legend is None else legend\n",
    "        self.fig, self.axes = plt.subplots(nrows, ncols, figsize=figsize)\n",
    "        # A single subplot is returned as a bare Axes; wrap it so self.axes[0] always works.\n",
    "        if nrows * ncols == 1:\n",
    "            self.axes = [self.axes]\n",
    "\n",
    "        def _configure():\n",
    "            set_axes(self.axes[0], xlabel, ylabel, xlim, ylim,\n",
    "                     xscale, yscale, legend)\n",
    "\n",
    "        # Called by add() after every cla() to restore the axis configuration.\n",
    "        self.config_axes = _configure\n",
    "        self.X = None\n",
    "        self.Y = None\n",
    "        self.fmts = fmts\n",
    "\n",
    "    def add(self, x, y):\n",
    "        \"\"\"Append one point per series and redraw the figure.\n",
    "\n",
    "        x and y may be scalars or sequences; a scalar x is shared by every y.\n",
    "        Pairs in which either value is None are skipped.\n",
    "        \"\"\"\n",
    "        ys = y if hasattr(y, \"__len__\") else [y]\n",
    "        count = len(ys)\n",
    "        xs = x if hasattr(x, \"__len__\") else [x] * count\n",
    "        if not self.X:\n",
    "            self.X = [[] for _ in range(count)]\n",
    "        if not self.Y:\n",
    "            self.Y = [[] for _ in range(count)]\n",
    "        for idx, (x_val, y_val) in enumerate(zip(xs, ys)):\n",
    "            if x_val is None or y_val is None:\n",
    "                continue\n",
    "            self.X[idx].append(x_val)\n",
    "            self.Y[idx].append(y_val)\n",
    "        ax = self.axes[0]\n",
    "        ax.cla()\n",
    "        for series_x, series_y, fmt in zip(self.X, self.Y, self.fmts):\n",
    "            ax.plot(series_x, series_y, fmt)\n",
    "        self.config_axes()\n",
    "        display.display(self.fig)\n",
    "        display.clear_output(wait=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def loss_fn(out, tar, pad_idx=2):\n",
    "    \"\"\"Token-level cross-entropy averaged over target positions != pad_idx.\n",
    "\n",
    "    out: logits whose last dimension is the vocabulary; all leading dimensions are\n",
    "        flattened (presumably (batch, seq, vocab) - confirm against the model).\n",
    "    tar: integer targets with the same leading dimensions as out.\n",
    "    pad_idx: target id excluded from the loss. Defaults to 2, the value this notebook\n",
    "        has always used. NOTE(review): train_seq2seq also feeds id 2 to the decoder as\n",
    "        <bos>; confirm which token id 2 really is in Lang.\n",
    "    \"\"\"\n",
    "    logits = out.reshape(-1, out.shape[-1])\n",
    "    targets = tar.reshape(-1)\n",
    "    return F.cross_entropy(logits, targets, ignore_index=pad_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_clipping(net, theta):\n",
    "    \"\"\"Rescale gradients in place so their global L2 norm is at most theta.\n",
    "\n",
    "    net: an nn.Module (its parameters with requires_grad are used) or any object\n",
    "        exposing a .params list of tensors.\n",
    "    Parameters whose .grad is None (e.g. not reached by the last backward pass)\n",
    "    are skipped instead of raising TypeError; with no gradients at all this is a no-op.\n",
    "    \"\"\"\n",
    "    if isinstance(net, nn.Module):\n",
    "        params = [p for p in net.parameters() if p.requires_grad]\n",
    "    else:\n",
    "        params = net.params\n",
    "    grads = [p.grad for p in params if p.grad is not None]\n",
    "    if not grads:\n",
    "        return\n",
    "    norm = torch.sqrt(sum(torch.sum(g ** 2) for g in grads))\n",
    "    if norm > theta:\n",
    "        for g in grads:\n",
    "            g.mul_(theta / norm)\n",
    "\n",
    "\n",
    "def _evaluate_seq2seq(net, valid_iter, device, bos_idx=2):\n",
    "    \"\"\"Return (mean loss, mean per-batch perplexity) of net over valid_iter.\n",
    "\n",
    "    Uses the same teacher forcing as training. Targets Y carry <eos> but no <bos>,\n",
    "    so the decoder input is <bos> followed by Y shifted right by one position.\n",
    "    Leaves net in eval mode.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    loss_sum = ppl_sum = 0.0\n",
    "    with torch.no_grad():\n",
    "        for batch in valid_iter:\n",
    "            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n",
    "            bos = torch.tensor([bos_idx] * Y.shape[0],\n",
    "                               device=device).reshape(-1, 1)\n",
    "            dec_input = torch.cat([bos, Y[:, :-1]], 1)\n",
    "            Y_hat, _ = net(X, dec_input, X_valid_len)\n",
    "            l = loss_fn(Y_hat, Y)\n",
    "            loss_sum += l.item()\n",
    "            ppl_sum += torch.exp(l).item()\n",
    "    return loss_sum / len(valid_iter), ppl_sum / len(valid_iter)\n",
    "\n",
    "\n",
    "def train_seq2seq(net, data_iter, valid_iter, lr, num_epochs, lang, device,\n",
    "                  first_train=True, min_ppl=1e9, i=70, bos_idx=2):\n",
    "    \"\"\"Train a sequence-to-sequence model with teacher forcing and Adam.\n",
    "\n",
    "    After every epoch the model is evaluated on valid_iter. When the mean validation\n",
    "    perplexity beats min_ppl, the encoder/decoder are saved under MODEL_ROOT (except\n",
    "    for the first improvement over the 1e9 sentinel). Loss and perplexity curves are\n",
    "    drawn live with two Animators.\n",
    "\n",
    "    net: called as net(X, dec_input, X_valid_len) -> (Y_hat, state).\n",
    "    data_iter / valid_iter: sized iterables of (X, X_valid_len, Y, Y_valid_len) batches.\n",
    "    first_train: if True, Xavier-initialise Linear/GRU weights; otherwise evaluate the\n",
    "        current weights first and use that perplexity as the checkpoint baseline.\n",
    "    min_ppl: best validation perplexity so far (overwritten when first_train is False).\n",
    "    lang, i: unused; kept for call compatibility.\n",
    "    bos_idx: token id prepended to decoder inputs (the notebook hard-codes 2).\n",
    "    \"\"\"\n",
    "    def xavier_init_weights(m):\n",
    "        if type(m) == nn.Linear:\n",
    "            nn.init.xavier_uniform_(m.weight)\n",
    "        if type(m) == nn.GRU:\n",
    "            for param in m._flat_weights_names:\n",
    "                if \"weight\" in param:\n",
    "                    nn.init.xavier_uniform_(m._parameters[param])\n",
    "\n",
    "    net.to(device)\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    net.train()\n",
    "    animator = Animator(xlabel='epoch', ylabel='loss',\n",
    "                        xlim=[0, num_epochs])\n",
    "    animator1 = Animator(xlabel='epoch', ylabel='ppl',\n",
    "                         xlim=[0, num_epochs])\n",
    "    # Defined up front so the final report works even when num_epochs == 0.\n",
    "    LOSS = PPL = 0.0\n",
    "    LOSS_AVG = PPL_AVG = 0.0\n",
    "    timer = Timer()\n",
    "\n",
    "    if first_train:\n",
    "        net.apply(xavier_init_weights)\n",
    "    else:\n",
    "        # Baseline: only checkpoints that beat the current weights get saved.\n",
    "        LOSS_AVG, PPL_AVG = _evaluate_seq2seq(net, valid_iter, device, bos_idx)\n",
    "        min_ppl = PPL_AVG\n",
    "        animator1.add(0, (PPL_AVG, PPL_AVG))\n",
    "        animator.add(0, (LOSS_AVG, LOSS_AVG))\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        net.train()\n",
    "        # Reset every epoch so each average covers only this epoch's batches.\n",
    "        LOSS = PPL = 0.0\n",
    "        for batch in data_iter:\n",
    "            optimizer.zero_grad()\n",
    "            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]\n",
    "            bos = torch.tensor([bos_idx] * Y.shape[0],\n",
    "                               device=device).reshape(-1, 1)\n",
    "            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # teacher forcing\n",
    "            Y_hat, _ = net(X, dec_input, X_valid_len)\n",
    "            l = loss_fn(Y_hat, Y)\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "            LOSS += l.item()\n",
    "            PPL += torch.exp(l.detach()).item()\n",
    "        LOSS /= len(data_iter)\n",
    "        PPL /= len(data_iter)\n",
    "\n",
    "        LOSS_AVG, PPL_AVG = _evaluate_seq2seq(net, valid_iter, device, bos_idx)\n",
    "        if PPL_AVG < min_ppl:\n",
    "            if min_ppl != 1e9:\n",
    "                # NOTE(review): saves the notebook-global `encoder`/`decoder`, not parts of\n",
    "                # `net`; correct only while those globals are the modules `net` wraps.\n",
    "                torch.save(encoder, MODEL_ROOT + \"trans_encoder.mdl\")\n",
    "                torch.save(decoder, MODEL_ROOT + \"trans_decoder.mdl\")\n",
    "            min_ppl = PPL_AVG\n",
    "\n",
    "        animator1.add(epoch + 1, (PPL, PPL_AVG))\n",
    "        animator.add(epoch + 1, (LOSS, LOSS_AVG))\n",
    "\n",
    "    print(f'loss {LOSS:.3f}, PPL {PPL:.3f}, {timer.stop()/60:.1f} '\n",
    "          f'min on {str(device)}')\n",
    "    print(f'loss {LOSS_AVG:.3f}, PPL {PPL_AVG:.3f} ')\n",
    "    print(lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.load unpickles its input, which can execute arbitrary code - only load trusted files.\n",
    "# Presumably the preprocessed train split, the Lang vocabulary and the validation split\n",
    "# produced by an earlier preprocessing step - TODO confirm.\n",
    "dataset = torch.load(DATA_ROOT + \"dataset\")\n",
    "lang = torch.load(DATA_ROOT + \"dialog.lang\")\n",
    "valid_data = torch.load(DATA_ROOT + \"validdata\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap both splits with the project's data.loaddata helper.\n",
    "# NOTE(review): 64 is presumably the batch size - confirm against loaddata and\n",
    "# consider moving it into a named constant.\n",
    "train_iter = loaddata(dataset, 64)\n",
    "valid_iter = loaddata(valid_data, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encoder and decoder take the identical positional argument list; keep it in one\n",
    "# tuple so the two halves of the model cannot drift apart.\n",
    "transformer_args = (\n",
    "    len(lang), key_size, query_size, value_size, num_hiddens,\n",
    "    norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,\n",
    "    num_layers, dropout)\n",
    "encoder = TransformerEncoder(*transformer_args)\n",
    "decoder = TransformerDecoder(*transformer_args)\n",
    "net = EncoderDecoder(encoder, decoder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '../../model/dialog/trans_encoder.mdl'",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mFileNotFoundError\u001B[0m                         Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-10-41bf1858bf89>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mencoder\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mMODEL_ROOT\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0;34m\"trans_encoder{}.mdl\"\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnum_examples\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m      2\u001B[0m \u001B[0mdecoder\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mMODEL_ROOT\u001B[0m \u001B[0;34m+\u001B[0m \u001B[0;34m\"trans_decoder{}.mdl\"\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnum_examples\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      3\u001B[0m \u001B[0mnet\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mEncoderDecoder\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mencoder\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdecoder\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/envs/py37/lib/python3.6/site-packages/torch/serialization.py\u001B[0m in \u001B[0;36mload\u001B[0;34m(f, map_location, pickle_module, **pickle_load_args)\u001B[0m\n\u001B[1;32m    592\u001B[0m         \u001B[0mpickle_load_args\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'encoding'\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m'utf-8'\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    593\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 594\u001B[0;31m     \u001B[0;32mwith\u001B[0m \u001B[0m_open_file_like\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mf\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m'rb'\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mopened_file\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    595\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0m_is_zipfile\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mopened_file\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    596\u001B[0m             \u001B[0;31m# The zipfile reader is going to advance the current file position.\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/envs/py37/lib/python3.6/site-packages/torch/serialization.py\u001B[0m in \u001B[0;36m_open_file_like\u001B[0;34m(name_or_buffer, mode)\u001B[0m\n\u001B[1;32m    228\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0m_open_file_like\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mname_or_buffer\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmode\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    229\u001B[0m     \u001B[0;32mif\u001B[0m \u001B[0m_is_path\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mname_or_buffer\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 230\u001B[0;31m         \u001B[0;32mreturn\u001B[0m \u001B[0m_open_file\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mname_or_buffer\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmode\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    231\u001B[0m     \u001B[0;32melse\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    232\u001B[0m         \u001B[0;32mif\u001B[0m \u001B[0;34m'w'\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mmode\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/anaconda3/envs/py37/lib/python3.6/site-packages/torch/serialization.py\u001B[0m in \u001B[0;36m__init__\u001B[0;34m(self, name, mode)\u001B[0m\n\u001B[1;32m    209\u001B[0m \u001B[0;32mclass\u001B[0m \u001B[0m_open_file\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0m_opener\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    210\u001B[0m     \u001B[0;32mdef\u001B[0m \u001B[0m__init__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mname\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmode\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 211\u001B[0;31m         \u001B[0msuper\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0m_open_file\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mself\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m__init__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mopen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mname\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmode\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    212\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    213\u001B[0m     \u001B[0;32mdef\u001B[0m \u001B[0m__exit__\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mself\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m*\u001B[0m\u001B[0margs\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: '../../model/dialog/trans_encoder.mdl'"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Resume from a saved checkpoint when one exists; otherwise keep the existing `net`\n",
    "# so Restart & Run All does not stop with FileNotFoundError.\n",
    "# NOTE(review): train_seq2seq saves to trans_encoder.mdl / trans_decoder.mdl without\n",
    "# num_examples, so these paths only match when num_examples formats to ''.\n",
    "encoder_path = MODEL_ROOT + \"trans_encoder{}.mdl\".format(num_examples)\n",
    "decoder_path = MODEL_ROOT + \"trans_decoder{}.mdl\".format(num_examples)\n",
    "if os.path.exists(encoder_path) and os.path.exists(decoder_path):\n",
    "    encoder = torch.load(encoder_path)\n",
    "    decoder = torch.load(decoder_path)\n",
    "    net = EncoderDecoder(encoder, decoder)\n",
    "else:\n",
    "    print(f\"No checkpoint found at {encoder_path}; keeping the current model.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam learning rate. NOTE(review): the recorded training call in this notebook passed\n",
    "# a literal 0.1 to train_seq2seq instead of this variable - confirm which is intended.\n",
    "lr = 2e-3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-16-13b1c991e5a4>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mtrain_seq2seq\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnet\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtrain_iter\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mvalid_iter\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m0.1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m10\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlang\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdevice\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mfirst_train\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mFalse\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m",
      "\u001B[0;32m<ipython-input-6-d48a01a18e20>\u001B[0m in \u001B[0;36mtrain_seq2seq\u001B[0;34m(net, data_iter, valid_iter, lr, num_epochs, lang, device, first_train, min_ppl, i)\u001B[0m\n\u001B[1;32m     77\u001B[0m             \u001B[0mdec_input\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mbos\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mY\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m:\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m)\u001B[0m  \u001B[0;31m# 强制教学\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     78\u001B[0m             \u001B[0mY_hat\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0m_\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnet\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mX\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdec_input\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mX_valid_len\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m---> 79\u001B[0;31m             \u001B[0ml\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mloss_fn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mY_hat\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mY\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m     80\u001B[0m             \u001B[0mppl\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mexp\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0ml\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m     81\u001B[0m             \u001B[0ml\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mbackward\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m<ipython-input-5-ae441a1c19cf>\u001B[0m in \u001B[0;36mloss_fn\u001B[0;34m(out, tar)\u001B[0m\n\u001B[1;32m      1\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mtorch\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mnn\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfunctional\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mF\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      2\u001B[0m \u001B[0;32mdef\u001B[0m \u001B[0mloss_fn\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mout\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtar\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 3\u001B[0;31m     \u001B[0mout\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mout\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcontiguous\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mview\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mout\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m      4\u001B[0m     \u001B[0mtar\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mtar\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcontiguous\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mview\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m-\u001B[0m\u001B[0;36m1\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      5\u001B[0m     \u001B[0;32mreturn\u001B[0m \u001B[0mF\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcross_entropy\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mout\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mtar\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mignore_index\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m2\u001B[0m\u001B[0;34m)\u001B[0m  \u001B[0;31m# 
pad\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
     ]
    },
    {
     "data": {
      "text/plain": "<Figure size 252x180 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQAAAAC1CAYAAACwAiEUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAbc0lEQVR4nO3de3xU9Z3w8c93JheSSSQkBETAAkpSNIAC+qBUikVRtype6l3Wx+VVdn3YFbe2Vts+3d0+j6+t223r2q1V6r3yYPESpLXeqiAVRbkIEkBqRIRIJJmEBHJPZr7PHzPRMeQymczMOcl8369XXsn85nfO+SavnO+c8zu/i6gqxpjU5HE6AGOMcywBGJPCLAEYk8IsARiTwiwBGJPCLAEYk8LSErVjEXkEuBioUtWScNnPgEuANuAj4GZVrQu/dxewGAgAt6rqy30dIy8vT08++eTE/AID1NjYiM/nczqMbrk5NnB3fG6ObcuWLX5VLezXRqqakC9gLjADKIsoWwCkhX++B7gn/PMpwHYgE5hIKDl4+zpGUVGRutXatWudDqFHbo5N1d3xuTk2YLP28zxN2C2Aqq4HaruUvaKqHeGXG4Fx4Z8XAk+paquqfgyUA2cmKjZjTIiTbQB/B7wY/nkscCDivYpwmTEmgRLWBtAbEfkh0AGs6Czqplq3fZRFZAmwBKCwsJB169YlIsQBa2hosNhi5Ob43BxbLJKeAETkJkKNg/PD9y0Q+sQfH1FtHHCwu+1VdTmwHKC4uFjnzZuXuGBj1HKghY23b2Tuirl40t33oGXdunW48e/Wyc3xuTm2WCT1v1NELgS+D1yqqk0Rb60BrhWRTBGZCEwG3k1mbPG0+/rd8DQc2XjE6VCM6VXCEoCIrATeBopFpEJEFgP/DeQCr4rINhF5AEBVdwKrgF3AS8BSVQ0kKrZEUlXq36oHoOmDpj5qG+OshN0CqOp13RQ/3Ev9u4G7ExVPMk1dM5UdF++gsazR6VCM6ZUjjYBDmYhQ8M0C+BV85aqvOB2OMb1yXwvVIFdxXwWNuxqhBDJGZzgdjjG9sgQQR80fN1O+rJzaF2uhEvb/bD/tte1Oh2VMjywBxJG/1A/AyMtHwkHYe8deGt5rcDgqY3pmCSCO/KV+fNN8ZE3KCo1oAGsINK5mCSBO2g61Ub+hnsIrwoOxRkD6yHQadtgVgHEvSwBx0rC9AU+WJ3T5DyDgm+qzKwDjapYA4iR/QT5zqufgm/rFWHFfiY/mvzbzRY9nY9zFEkAcdJ7g3mwvIl+Ma5rwkwmc/dnZXyozxk0sAcRB1VNVbDptE60HW79Unp6XjifD/sTGvey/Mw78z/lpP9ROxvFf7vijQaX8n8s5tPKQQ5EZ0ztLAAMUaA5Q82INBQsLEM+XL/XFI1SXVlPzhxqHojOmd5YABujwnw8TbAx+8fivi5ypOfYkwLiWJYAB8j/nxzvcS968vG7f95X4aPqgiWB7MLmBGRMFGw04QPkX5eOb5uuxsc831Ye2K81/bcZ3qjunkzapyxLAAI26elSv7/tKfGSMyaCtqs0SgHEdSwADULe+jmGThjFs3LAe6+RMy+Hsg2cnMSpjomdtADFSVXbfsJsP//FDp0MxJmaWAGJ0dPNRWitaKby875WYDtx7gO0LtichKmP6xxJAjPylfvBCwSUFfdYNNAQ4/OphOho6+qxrTDJZAohR9XPV5M3LIz0/vc+6vpJQ41/TLpsl2LiLJYAYNO9rpvmvzVFd/sMXCcA6BBm3sacAMciakMVZB8/Cm+WNrv6kLDxZHhp3WAIw7pLIhUEeEZEqESmLKLtKRHaKSFBEZnWpf5eIlIvIHhG5IFFxxUvm8ZmkDY8uf4pHKLyqkIyxNkuwcZdE3gI8BlzYpawMuAJYH1koIqcA1wKnhre5X0Si+3hNspb9Lbx/0fsc3Xa0X9tNeXwKJ373xARFZUxsEpYAVHU9UNulbLeq7umm+kLg
KVVtVdWPgXLgzETFNhD+1X5qX6rFm93//KSqaNBmBzLu4ZZGwLHAgYjXFeEy16l+rprsU7PJLsru13ZH3jnCm8PfpO6NusQEZkwM3NII2N2cWd1+VIrIEmAJQGFhYXLXaq8D/gLcQJ/HPWYd+RrgKGwv3d79b5tEbl/j3s3xuTm2WLglAVQA4yNejwMOdldRVZcDywGKi4s1mWu1Vz5SyZ7gHmbeNpPcGbm91u26jryqsqFgA4WthRTPK05wpL1z+xr3bo7PzbHFwi23AGuAa0UkU0QmApOBdx2O6Rje47wUXFpAzuk5/d5WRPCV+OxRoHGVRD4GXAm8DRSLSIWILBaRy0WkAjgLeEFEXgZQ1Z3AKmAX8BKwVFUDiYotVqO+NYqpz0+NeZZfX0lonQCbJty4RcJuAVT1uh7eKu2h/t3A3YmKZ6BaKlpIy0sjLSf2P9nIS0aSlpdGsDWId5grn3KaFOOWNgDX2/u9vdS/Vc/sfbNjvgLIvyCf/Avy4xyZMbFzSxuAqwVbg9S8UEP+gvwBL/LRUd9xzPoBxjjFEkAUDr92mMDRwBfr/g3Apmmb+OiOj+IQlTEDZwkgCv5SP95cLyPmjxjwvjobAo1xA0sAfdCA4n/eT8E3C/BkDvzP5Svx0bTbpgk37mCNgH3xwPQ/T49bqvSV+NA2pbm8Gd8UmyXYOMsSQB9EhJxp/e/405PO5cMbdzRaAjCOs1uAXqgq5d8p58g7R+K2z+yvZjP5/snkntF7V2JjksESQC8a3mug4pcVNO6MX6Odd5iXsbeMJWtiVtz2aUysLAH0wl/qB090M//2R8uBFmpesBWDjfMsAfSiurSavLl5ZBTGdyqvzx79jB2X7CDQ5LrhDibFWALoQdNfm2ja2RSXzj9d+Up8oNC4y/oDGGdZAuhBy/4WMsdnMvKyBCUAbJpw4zx7DNiD/PPymf1J7AN/epN1UhaeYTZNuHGeXQF0I9gaRIOakJMfQLxC9inZdgVgHGcJoBuVD1Xy1glv0VbVlrBjFD9UTNGDRQnbvzHRsFuAblSXVpOWl0bGqMQt5JF7unUEMs6zK4Au2mvbqVtXR+EV0a37F/Nxatqp+K8KGj+w2wDjHEsAXdT8oQYCJOTxX6RgS5Dy28qpe60uoccxpjeWALqoLq0mc1wmubMSe4mecUIGaXlp1hBoHBVVAhCRZSJynIQ8LCJbRWRBooNzwrhbx3HSf56UsCcAnUQE31QfDTsaEnocY3oT7RXA36nqEWABUAjcDPw0YVE5aMQ3RjDqmlFJOZZNE26cFm0C6Pw4/BvgUVXdjuMLXMVf1e+rOPpe/1b9HQjfVB+BIwHaPkvc40ZjehNtAtgiIq8QSgAvi0gu0OucViLyiIhUiUhZRFm+iLwqIh+Gv4+IeO8uESkXkT0ickEsv8xABNuC7Pn7PXx636dJO+boRaM5p+EcMsdkJu2YxkSKNgEsBu4EzlDVJiCd0G1Abx4DLuxSdifwmqpOBl4Lv0ZETgGuBU4Nb3O/iCR15Yy6tXUE6gOMvCKxrf+R0nLSYlpm3Jh4iTYBnAXsUdU6EbkR+BFQ39sGqroeqO1SvBB4PPzz48BlEeVPqWqrqn4MlANnRhlbXFSXVuPxeRhx/sBn/u2P/ffs58DPD/Rd0ZgEiLYn4G+A6SIyHbgDeBh4Avh6P483WlUrAVS1UkQ6W9vGAhsj6lWEy46RkOXBg8DTwCz4y8a/DHx/9GMZ6WeBOvhoZvLWCnD7Etdujs/NscVEVfv8AraGv/8YWBxZ1sd2E4CyiNd1Xd4/HP7+a+DGiPKHgSv72n9RUZHGQ+OeRn3D94Z+tuKzuOxPVXXt2rVR1Sv/brmuy1yngfZA3I7dl2hjc4qb43NzbMBmjeJ8jvyK9grgqIjcBSwCzgnfn6fHkG8OicgYDX36jwGqwuUVwPiIeuOAgzHsPybZRdnM8c9J+LP/7vhKfGir0vJRC9nF2Uk/vklt0bYBXAO0EuoP8Bmhy/Of
xXC8NcBN4Z9vAp6PKL9WRDJFZCIwGXg3hv3HzDvMG5eFP/qrc5pw6xBknBDVf3z4pF8BDBeRi4EWVX2it21EZCXwNlAsIhUisphQ56HzReRD4Pzwa1R1J7AK2AW8BCxV1aRMmNdQ1sCmaZs4sjl+U3/3R/aUbNIL0wkcsfkBTfJFdQsgIlcT+sRfR6gD0K9E5Huq+kxP26jqdT28Nb+H+ncDd0cTTzz5V/tp3NFI5lhnnsV7s7zMqZrjyLGNibYN4IeE+gBUAYhIIfBnoMcEMFj4S/0cN/s464xjUlK0N72ezpM/rKYf27pWy/4WGrY2JHzob1+qn61m02mbCDTbbYBJrmivAF4SkZeBleHX1wB/SkxIyeNf7QdIyMy//aGqNG5vpGl3E7kzbKYgkzxRJQBV/Z6IXAnMIdQGsFxVSxMaWRJkfzWbE5aeQHaRs4/fIqcJtwRgkinqOQFV9VlC/daGjPwF+eQvyHc6DLJOzkIyxaYJN0nXawIQkaNAd4PVBVBVPS4hUSVB4+5GvD4vw04c5nQoeNI8+Kb4bHYgk3S9JgBVHbLXo3vv2kvD1oaELf7RXwUXF9BxtMPpMEyKSclpwQONAQ6/fJgx3x7jipMfYOL/meh0CCYFDfpHebGofaWWYEvQ8db/rlQVDdj0YCZ5UjIB+Ev9pOWnMXzucKdD+VxbdRsbCjdwcHnSxkAZk3oJQANKzZ9qKLikAE+ae3799JHpaIdaQ6BJqpRrAxCvcObOM13X605EQrME26NAk0Tu+QhMoozRGWRNyHI6jGPYNOEm2VIqAagqu67fRe3LXacqdIecqTl0HO6grdKmCTfJkVIJ4Ojmo1StrKLtkDtPsOHnDGf898YPwRUXjFulVBuAf7UfvKFON26UMy2HnP/IcToMk0JS6grAX+on7+t5pOfHMp1hcgSaArR80uJ0GCZFpEwCaNrTRNPuJsfH/vel7PIyyq4s67uiMXGQMgmg/XA7uWfmMnKhuxOAr8RH064m6xFokiJlEsDw2cOZ+c5Mho13fvRfb3wlPoLNQZr3NjsdikkBKZEAOho6Bs1Iu85pwq1HoEmGlEgAh353iA2FG2g54P7GNd8UHwjWI9AkRUo8BvSv9jPsxGFkjnP/zL9en5fi3xaTe8aQnYrBuIgjVwAiskxEykRkp4jcFi7LF5FXReTD8Pe4LNPbXtdO3et1jLx8pGvG/vdlzOIx5Eyz/gAm8ZKeAESkBPg2oeW/pwMXi8hk4E7gNVWdDLwWfj1gtS/Uoh3q+sd/kdr8bfjX+Am2Bp0OxQxxTlwBTAE2qmqTqnYAbwCXAwuBx8N1Hgcui8fB/Kv9ZIzJ4LgzB8/0hXWv11G2sIymD5qcDsUMcZLskWciMoXQoqBnAc2EPu03A4tUNS+i3mFVPeY2QESWAEsACgsLZ65atar3A34MHAJmxyf+aDU0NJCTE+Nl/D7gZuAHhFZQjLMBxZYEbo7PzbGde+65W1R1Vr826u964vH4AhYDW4H1wAPAL4G6LnUO97WfoqKi/i6hnjQDWUc+0BbQdenrtPz75fELKIKb17hXdXd8bo4N2Kz9PBcdaQRU1YdVdYaqzgVqgQ+BQyIyBiD8vaq3fUSj8tFKal6sGehuks6T7iF7SrY9CjQJ59RTgFHh7ycCVxBacmwNcFO4yk2EbhNipgFl7/f3cuh3hwayG8d0Tg5iTCI51Q/gWREpANqBpap6WER+CqwSkcXAfuCqgRygfkM97dXtg6r1P9KEH0+wmYFMwjmSAFT1nG7KaoD58TqGv9SPZAr5F/a99Jeqsnz9XnYePEK610NGmpDh9ZDu9ZCeFvqe4ZXQ63BZ5OuMNM8X9b1CepqHqqaBPcLLLnZ2vUKTGoZkT0BVxb/az4jzRpCW2/ev+OiGffz7ix8wNi80T2BbIEh7IEh7R5D2gNIWiO1kLmsv466LppCV4e33tsH2IJUPVeKb6iPva3kxHd+YvgzJBNB2qA0NKoWXF/ZZd+Pe
Gu7+027OP2U0D944E4/n2N6CqkpHUMNJQWkNBGgPaDhBBMMJI/R+W0fo9f97/T2eePsT3vqohnuvOY2Ssf1bg0DShL137OX4/3m8JQCTMEMyAWQen8nsfbP7HFN/sK6ZpSu28pWCbH5x9fRuT34ITdmdHr7kJwOg7xmFpDKTm86bye1Pb+Py+zfwnfOLWTJ3Et4ejtHdMa0h0CTakBwNqEFFRHpd+KOlPcA/PLmF1o4gyxfNIndY/KcJ+9rkkby0bC7nTRnNPS99wPW/3cinddGP8/dN9dGwo8EaA03CDLkE0Ly3mbeOf4vaV3qe+ltV+dHqMt6vqOcXV0/n5FGJ69k1wpfB/TfM4GffmkbZp/VceO96nt/2aVTb+kp8dNR0uHYWYzP4DbkE4F/tp726nazJPS/88eTGT3hmSwW3zp/MglOPT3hMIsJVs8bz4rK5FI3OZdlT21j21HvUN7f3ul3n5CDNe2x2IJMYQy8BlPrxTfeRNbH7BLBpXy3/9oddfOOro7ht/uSkxnZiQTa/XzKb75xfxB/fr+Sie9ezcW/PPRWHzxnO1+q+Rt7X85IXpEkpQyoBtFW1Ub+hvsdlvz+rb+GWJ7cybkQWv7zmtB4b/RIpzevh1vmTefaWs8lI83Ddbzfy0xc/oK3j2EeNngwPacOHZDutcYkhlQD8a/ygdPv4r7UjwC0rttDU1sHyv53F8Cxn1wY4bXweL9x6DteeMZ4H3viIy+/fQHnV0WPqVT5WSfl3yx2I0KSCIZUAcmfkcuIPTsQ3zXfMe/+6Zhfv7a/j51dNp2i0O6bb8mWm8e9XTGP5oplU1rfwzfve5Im3932p1b9hWwMHf3MQDdqTABN/Qy4BTLp70jFTf618dz8r393P/5p3EhdNHeNQdD1bcOrxvHTbOcyeVMCPn9/JzY9toupoaALTnKk5BJuCtHzs/glNzeAzZBJAw/YG6jfWH/NJuXX/Yf7l+Z3MLSrk9gXFDkXXt1G5w3js5jP4ycJTefujGi689y+8uusQvhKbJtwkzpBJAPvv2U/ZpWUQcf5XHW3hlie3MHp4Jvdde1rUvfCcIiL87VkT+OM/fY3jjxvGt5/YzH+U7wOgYUeDs8GZIWlIJIBgW5CaF2oouLQA8YZO8raOIEtXbOVIcwfLF80iLzvD4SijN3l0LquXzuHvvz6JFWUH2DwbqgbnqGbjckPiGVPd2joCRwJfevz3f1/YxaZ9h7nvutOZMmbwTAjaKSPNw10XTWFe0ShuP24be7L9nMkEp8MyQ8yQSADVpdV4fB5GnBeaQ/TpzQd44u1PWDJ3EpdOP8Hh6AbmrJMKePG2uXTEOCTZmN4M+gSgqtStraPgogK8w7y8X1HHD1eXMefkAu64wL2Nfv3hdJ8FM3QN+gQgIpzx/hm017bjb2jlH363hcKcTH513QzSvEOiicOYhBkSZ4gn04NnVDpLV2ylprGNBxfNJN83eBr9jHHKoL8C2HHpDgqvKmR5Xi3vfFzLL66e3u/Zd4xJVYM7AbRBzR9qOFDi4ZHgPm6eM4ErZoxzOipjBo3BfQsQ7hvzb037+R8T8/nB30xxNh5jBpnBnQCOwv7ximd0Bv99/YzQnH3GmKg5tTLQP4vIThEpE5GVIjJMRPJF5FUR+TD8/ZiFQY/RCptO7uCBG2dSmJuZhMiNGVqSngBEZCxwKzBLVUsAL3AtcCfwmqpOJrRi8J197aspU5m3dBLTx+clMGJjhi6nrpnTgCwRSQOygYPAQuDx8PuPA5f1tZPAGLjyypMSFaMxQ544MeW0iCwD7gaagVdU9QYRqVPVvIg6h1X1mNsAEVkCLAEoLCycuWrVqiRF3T9uXkfezbGBu+Nzc2znnnvuFlWd1a+N+rue+EC/gBHA60AhoRU2VgM3AnVd6h3ua19FRUWxLqWecG5eR97Nsam6Oz43xwZs1n6ej07cApwHfKyq1araDjwHnA0cEpExAOHvVQ7EZkxKcSIB7Admi0i2hObu
mg/sBtYAN4Xr3AQ870BsxqSUpPcEVNV3ROQZYCvQAbwHLAdygFUisphQkrgq2bEZk2ocaQSMFxE5CuxxOo4ejAT8TgfRAzfHBu6Oz82xFatqv6a8HtxjAWCP9rfVM0lEZLPFFhs3x+f22Pq7jfWdNSaFWQIwJoUN9gSw3OkAemGxxc7N8Q2p2AZ1I6AxZmAG+xWAMWYABm0CEJELRWSPiJSLSJ8jB5NFRMaLyFoR2R0e8rzM6Zi6EhGviLwnIn90OpZIIpInIs+IyAfhv99ZTsfUqbsh7A7G8oiIVIlIWURZ/4fTM0gTgIh4gV8DFwGnANeJyCnORvW5DuB2VZ0CzAaWuii2TssI9b50m/8CXlLVrwLTcUmMvQxhd8pjwIVdyvo9nB4GaQIAzgTKVXWvqrYBTxEaTuw4Va1U1a3hn48S+ice62xUXxCRccA3gYecjiWSiBwHzAUeBlDVNlWtczSoL+tuCLsjVHU9UNuluN/D6WHwJoCxwIGI1xW46CTrJCITgNOBdxwOJdK9wB2A25YamgRUA4+Gb08eEhGf00EBqOqnwH8S6qJeCdSr6ivORnWM0apaCaEPIWBUNBsN1gTQ3TK/rnqcISI5wLPAbap6xOl4AETkYqBKVbc4HUs30oAZwG9U9XSgkSgvYxMtfD+9EJgInAD4RORGZ6OKj8GaACqA8RGvx+HgJVlXIpJO6ORfoarPOR1PhDnApSKyj9Bt0zdE5ElnQ/pcBVChqp1XS88QSghu0NMQdjeJaTj9YE0Am4DJIjJRRDIINciscTgmAMJDnB8GdqvqL5yOJ5Kq3qWq41R1AqG/2euq6opPMlX9DDggIp0LOs4HdjkYUqSehrC7SUzD6QflYCBV7RCRfwReJtQi+4iq7nQ4rE5zgEXADhHZFi77gar+ybmQBo1/AlaEk/pe4GaH4wF6HcLuCBFZCcwDRopIBfAvwE+JYTi99QQ0JoUN1lsAY0wcWAIwJoVZAjAmhVkCMCaFWQIwJoVZAjAJJyLz3Dby0IRYAjAmhVkCMJ8TkRtF5F0R2SYiD4bnDWgQkZ+LyFYReU1ECsN1TxORjSLyvoiUdo4/F5GTReTPIrI9vE3n6q05EWP9V4R71BmHWQIwAIjIFOAaYI6qngYEgBsAH7BVVWcAbxDqdQbwBPB9VZ0G7IgoXwH8WlWnE+ovXxkuPx24jdD8DZMI9Zg0DhuUXYFNQswHZgKbwh/OWYQGlASB34frPAk8JyLDgTxVfSNc/jjwtIjkAmNVtRRAVVsAwvt7V1Urwq+3AROANxP+W5leWQIwnQR4XFXv+lKhyP/uUq+3vuO9Xda3RvwcwP73XMFuAUyn14Bvicgo+HyOua8Q+h/5VrjO9cCbqloPHBaRc8Lli4A3wvMeVIjIZeF9ZIpIdjJ/CdM/loUNAKq6S0R+BLwiIh6gHVhKaGKOU0VkC1BPqJ0AQkNOHwif4JEj9xYBD4rIT8L7sEVeXcxGA5peiUiDquY4HYdJDLsFMCaF2RWAMSnMrgCMSWGWAIxJYZYAjElhlgCMSWGWAIxJYZYAjElh/x9yklCXxYGpPQAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": "<Figure size 252x180 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQsAAAC1CAYAAABI4trjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAMe0lEQVR4nO3dfYwcBR3G8e9jK2pLtRIKlrZaoXBSjbxIKFg1FxADSsA/VEBBo0ZirAq+BChqjP5l4jvRgA2gIBUEhFgRBcQeihGkb1igVgjyUqgWRIFiFIo//5i5urfs3f1u9nZnj30+SXM3szO7T9u752b2dueniMDMbDwvqDuAmU0NLgszS3FZmFmKy8LMUlwWZpbisjCzlOl1B+im2bNnx6JFi+qO0dJTTz3FzJkz647RkrNV08vZ1q5d+2hEzJnIPn1VFnvuuSdr1qypO0ZLQ0NDDA4O1h2jJWerppezSbp/ovv4NMTMUlwWZpbisjCzFJeFmaW4LMwsxWVhZikuCzNLcVmYWYrLwsxSXBZmluKyMLMUl4WZpbgszCyl1rKQdLSkzZLukXRWi9sl6Zzy9j9KOrjp9mmS1ku6pnupzfpTbWUhaRrwXeAYYDFwkqTFTZsdA+xb/jkVOLfp9tOATR2OambUe2RxKHBPRNwbEU8DlwHHN21zPHBxFG4BZkuaCyBpPvAO4PxuhjbrV3Ve/GYe8GDD8hZgSWKbecBW4FvAGcCssR5E0qkURyXMmTOHoaGhdjJ3zPbt252tAmfrnjrLQi3WNY9Ha7mNpGOBbRGxVtLgWA8SESuAFQADAwPRq1cu6uWrKjlbNb2crYo6T0O2AAsalucDDye3WQocJ+k+itOXIyRd0rmoZlZnWdwG7Cvp1ZJ2AU4EVjVtswp4f/lbkcOAxyNia0Qsj4j5EbGw3O/XEXFyV9Ob9ZnaTkMiYoekjwPXAdOACyPiTkkfLW8/D7gWeDtwD/Av4IN15TXrd7Ve3TsirqUohMZ15zV8HsCyce5jCBjqQDwza+BXcJpZisvCzFJcFmaW4rIwsxSXhZmluCzMLMVlYWYpLgszS3FZmFmKy8LMUlwWZpbisjCzFJeFmaW4LMwsxWVhZilTcm6IpAWSVkvaJOlOSad1P71Zf5mqc0N2AJ+JiP2Bw4BlLfY1s0k0JeeGlNfhXAcQEU9SDBqa183wZv2mzrIYbSbIhLaRtBA4CLh18iOa2bApOTdk543SrsBPgNMj4omWD+IhQ21ztmp6OVsVdZZFO3NDkPRCiqJYGRFXjfYgHjLUPmerppezVTEl54ZIEnABsCkivtHd2Gb9aarODVkKnAJslLShXHd2OVrAzDpgSs4NiYibaf18hpl1iF/BaWYpLgszS3FZmFmKy8LMUlwWZpbisjCzFJeFmaW4LMwsxWVhZikuCzNLcVmYWYrLwsxSXBZmluKyMLMUl4WZpYx5PQtJnx7rdl+lyqx/jHdkMWucP22pOmQos6+ZTa4xjywi4kudeuCGIUNHUVyY9zZJqyLirobNGocMLaEYMrQkua+ZTaLUcxaS9pb0M0mPSNom6aeS9m7zsSsPGUrua2aTKPsE54+Ay4G5wF7AFcClbT52O0OGMvua2STKXrBXEfHDhuVLyitzt6OdIUOZfYs78JChtjlbNb2crYpsWayWtJziaCKAE4CfS9oNICIeq/DY7QwZ2iWxL2U2Dxlqk7NV08vZqsiWxQnlx4+UH4d/sn+IojyqPH+xc8gQ8BDFkKH3Nm2zCvi4pMsonuAcHjL0SGJfM5tE2bJYDHwMeBNFOfwWODci/l31gdsZMjTavlWzmNn4smVxEfAEcE65fBJwMfCedh686pCh0fY1s87JlsVARBzQsLxa0u2dCGRmvSn7q9P15WBiACQtAX7XmUhm1ouyRxZLKKaZP1AuvxLYJGkjxdnC6zuSzsx6RrYsju5oCjPreamyiIj7Ox3EzHqbr2dhZikuCzNLcVmY
WYrLwsxSXBZmluKyMLMUl4WZpbgszCzFZWFmKS4LM0txWZhZSi1lIWk3STdIurv8+PJRtms5SEjSVyX9qRw8dLWk2V0Lb9an6jqyOAu4MSL2BW4sl0doGCR0DMVl/U6StLi8+QbgdeVb4/8MLO9KarM+VldZHE9xqT7Kj+9ssc2og4Qi4vqI2FFudwvF1b3NrIPqKos9I2IrQPlxjxbbZAcJfQj4xaQnNLMRshe/mTBJvwJe0eKmz2XvosW6EYOEJH0O2AGsHCOHhwy1ydmq6eVslURE1/8Am4G55edzgc0ttjkcuK5heTmwvGH5A8DvgRnZx91vv/2iV61evbruCKNytmp6ORuwJib4fVvXaciq8pt9+Jv+py222TmESNIuFIOEVkHxWxLgTOC4iPhXF/Ka9b26yuIrwFGS7gaOKpeRtJeka6EYJAQMDxLaBFwe/x8k9B1gFnCDpA2Szmt+ADObXB17zmIsEfF34MgW6x+mmEA2vNxykFBELOpoQDN7Dr+C08xSXBZmluKyMLMUl4WZpbgszCzFZWFmKS4LM0txWZhZisvCzFJcFmaW4rIwsxSXhZmluCzMLMVlYWYpLgszS3FZmFnKlBwy1HD7ZyWFpN07n9qsv03VIUNIWkBxSb4HupLYrM9NySFDpW8CZ9A0HsDMOqOWa3DSNGRIUnbI0BIASccBD0XE7VKr8SL/57kh7XO2ano5WxVTbsiQpBnlfbwtcycRsQJYATAwMBCDg4PJh++uoaEhnG3inK17OlYWEfHW0W6T9DdJc8ujirnAthabbQEWNCzPBx4G9gFeDQwfVcwH1kk6NCL+Oml/ATMbYcoNGYqIjRGxR0QsjIiFFKVysIvCrLOm6pAhM+uyKTlkqGmfhZOdz8yey6/gNLMUl4WZpbgszCzFZWFmKS4LM0txWZhZisvCzFJcFmaW4rIwsxSXhZmluCzMLMVlYWYpLgszS1FE/1zCUtKTwOa6c4xid+DRukOMwtmq6eVsAxExayI71HUNzrpsjohD6g7RiqQ1zjZxzlaNpDUT3cenIWaW4rIws5R+K4sVdQcYg7NV42zVTDhbXz3BaWbV9duRhZlV1BdlMd6A5TpJWiBptaRNku6UdFrdmRpJmiZpvaRr6s7STNJsSVdK+lP573d43ZmGSfpU+f95h6RLJb24xiwXStom6Y6Gdanh5I2e92Ux3oDlHrAD+ExE7A8cBizrsXynUYxi6EXfBn4ZEa8BDqBHckqaB3wSOCQiXgdMo5h7U5cfAEc3rRt3OHmz531ZMP6A5VpFxNaIWFd+/iTFF/y8elMVJM0H3gGcX3eWZpJeCrwFuAAgIp6OiH/WGmqk6cBLJE0HZlBM06tFRPwGeKxpdWY4+Qj9UBatBiz3xDdjM0kLgYOAW2uOMuxbFJPq/1tzjlb2Bh4Bvl+eJp0vaWbdoQAi4iHga8ADwFbg8Yi4vt5UzzFiODnQajj5CP1QFi0HLHc9xTgk7Qr8BDg9Ip7ogTzHAtsiYm3dWUYxHTgYODciDgKeInEo3Q3l+f/xFDN59wJmSjq53lTt64eyGG3Acs+Q9EKKolgZEVfVnae0FDhO0n0Up25HSLqk3kgjbAG2RMTwUdiVFOXRC94K/CUiHomIZ4CrgDfWnKnZ38qh5IwxnHyEfiiLlgOWa860k4pR8BcAmyLiG3XnGRYRyyNifjke8kTg1xHRMz8dy0HYD0oaKFcdCdxVY6RGDwCHSZpR/v8eSY88+dogM5x8hOf9G8kiYoek4QHL04ALe2zA8lLgFGCjpA3lurPLOa82tk8AK8sfAvcCH6w5DwARcaukK4F1FL/tWk+Nr+aUdCkwCOwuaQvwRYph5JdL+jBFub173PvxKzjNLKMfTkPMbBK4LMwsxWVhZikuCzNLcVmYWYrLwnqOpMFefJdrv3NZmFmKy8Iqk3SypD9I2iDpe+W1L7ZL+rqkdZJulDSn3PZASbdI+qOkq4evnyBpkaRfSbq93Gef8u53
bbhWxcrylZBWI5eFVSJpf+AEYGlEHAg8C7wPmAmsi4iDgZsoXi0IcDFwZkS8HtjYsH4l8N2IOIDi/RNby/UHAadTXINkb4pXulqNnvcv97aOORJ4A3Bb+UP/JRRvRvov8ONym0uAqyS9DJgdETeV6y8CrpA0C5gXEVcDRMS/Acr7+0NEbCmXNwALgZs7/reyUbksrCoBF0XE8hErpS80bTfW+wnGOrX4T8Pnz+Kv1dr5NMSquhF4l6Q9YOc1HV9F8TX1rnKb9wI3R8TjwD8kvblcfwpwU3ndji2S3lnex4skzejmX8Ly3NZWSUTcJenzwPWSXgA8AyyjuAjNayWtBR6neF4DirdBn1eWQeM7RE8Bvifpy+V9jPvuR6uH33Vqk0rS9ojYte4cNvl8GmJmKT6yMLMUH1mYWYrLwsxSXBZmluKyMLMUl4WZpbgszCzlf+9ThNfMS1oWAAAAAElFTkSuQmCC\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_seq2seq(net, train_iter, valid_iter, 0.1, 10, lang, device, first_train=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(encoder, MODEL_ROOT + \"trans_encoder{}.mdl\".format(num_examples))\n",
    "torch.save(decoder, MODEL_ROOT + \"trans_decoder{}.mdl\".format(num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP0AAAC4CAYAAAAooOojAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi40LCBodHRwOi8vbWF0cGxvdGxpYi5vcmcv7US4rQAAF9ZJREFUeJzt3Xt8VPWZx/HPM5lMGDK5M0mQAFJCRMQb8YY3QGu9VFutdFutL3XbLdiuZcV7deu21lbrpeq63aptd1H7YlekitZLbWtB6HppiRbRVgmgUoJMJCEhEwIhybN/nJMwiSEZQk7m9rxfr7ySM+c3Zx5CvnPOnOdcRFUxxmQOX6ILMMaMLAu9MRnGQm9MhrHQG5NhLPTGZBgLvTEZxkJvTIax0BuTYSz0xmQYC70xGcaf6AIORGFhoVZWVia6jH61traSm5ub6DL6ZbUNTTLXBlBTU7NNVcODjUvp0JeVlbF69epEl9GvFStWMHv27ESX0S+rbWiSuTYAEfkwnnG2eW9MhrHQG5NhLPTGZJiU/kxvko92KZ0tnXQ0d/R8dbV2UXxmMQDbnt7Gjtd39MzrbO5EsoXpT053FnAPrH96PWWXlZF3VF4C/yXpy0Jveunc2Ul7pJ3OHb2DG54bJmtUFg0vNNDwTEOveZ3NnVSvrsaX42P9Veupe6Cu90J9MKtjFiLCtl9vI/JIhKyCLPwFfvwFfgJjA3vHhqHuP+vYfN9mco/IpfzSckq/UkpOec7I/iLSmIU+zWmXsnvLbrJLsskKZtH0chPcCmvuWENn895gH7X8KEZXjWbLQ1vYcPWGTyyn8JRCsiZm0fp2Kx//6mP8Bf6e4AamBOhq78KX42PM+WMYdfCovfPznWCjgEDVg1Uc8rNDEJH+C74UTrz7ROr/t56tj2xlw7UbaP1rK1N/MRVVRdsVX459Kj0QFvo0oF2Kdiq+bB9tG9rY8vAW2mrbaFvvfHW1dXHkS0dSdFoRHTs6oBY6DurAn+8nMC6Av8DfE6TizxTj/y9/z1q4O7iBg5y18YTrJjDhugn7rKXotCKKTiva53yff/DAZhdnM+6b4xj3zXG0vtuK+J03iOgbUdZ8eg3hL4Upv6yc/BPy9/3mYfbJQp9iOpo7qF9S7wQ6JthVP62i/LJy9mzfw+b7NhOcHCRYGaTojCKClUGCVUEAxpw3BvKgenZ1v8vPPSyX3MOS5wCU3Kl7a/EFfRR/tpjIoxE+eugjglVByi8tZ9y3xuHPtz/leNlvKslop9K0qql3qGvbKLu0jAnXTqCzrZN189YhASH4qSDBKUGKPl3E6KmjAcibkcepO09FstJvDZg7LZdpv5xGx44OPl76MVsf2cqmuzZRsbACgNa/tpIzIQd/yP6sB2K/nQTYvWU3O9ft7BXsvBl5TLx5IgBvnfkW2q69gp0zztmRFSgLcPz7xzNq/Kh+gy2+9At7X/58P2O/OpaxXx3Lnu17yBqdharyzhffYdeHuwjPdTb/C2cVZsTvY39Z6D2gqrANml5uYmftTtpq2/Dn+3tC/eZJb7Lrg10APcHOneZsxkqWcNQfjiIwLtBvsEWE4MHBkf0HJbHsouyen6seqiLySIT6JfVEHomQMzGHSbdNovyS8gRWmHws9MNEO7UnoGvPWwvPwV/4CwCSLRSdvnfnVuX9lfiCPoJTgv0Gu+CkgpErPE2ICIUnF1J4ciGV91ey7eltbH1kK76As+Nw95bdNDzbQPgfwmQXZg+ytPRmoT8Au7fspuG5Bhqea2DHazuY+eFMfDk+Sr9cSuOkRo743BEEK4OMmtA72GM+NyaBVae/rNFZlF1URtlFZT2PbXt6G7XfrKV2QS1jzh9D+WXlFJ1RFFc3Id1Y6Ieg8feNbLx+I9E3owDkTMghfGGYzmgnvhwf5ZeU827FuxTPLk5wpabbQVccRN4xeUQejRBZHOHjxz8mpyKH49Yd
R1YwK9HljSgL/SA6dnTQ+NtGGp9rpPxyZ+eQP89PVm4Wk26fRMm5JeQelmv94iQnIuQfm0/+sflMvmcyDc810Pp2a0/gaxfUEpwcpPTiUgLhwCBLS20W+n507upky0+30PBcA80rm9E9ir/QT8GsAgpnFZJ/fD5Hrzo60WWaIfIFfIQvCBO+wLneRFd7Fzte30HdA3VsuHYDxecUU35ZOSXnlvTsE0gnFnqc//TmVc10NHUQvjCML+Bj0+2byA5nU7GwgpJzS8ifmZ+Rn/8ygS/go/r1aqJvR4k8EiHyywgNzzQw+d7JjL9qPNqpaXU+qiehF5E8YAHwPjAbuFNV18fMPxQ4C2gDyoHv4fxarwa2AtXAj1V1kxf1AbTXt9P4QiMNzzbQ+GIjnS2dBA8JEr4wjPiE4949juzizN7Lm2lC00OE7gox6fZJbP/ddvKqnbP8IosjbLp9E3wR5685xXm1pg8DK1V1lYhsAuYD18XMn4/zRrBFRH4CFAAHA9mq+piIvAdcANw/XAWpKq1vtZJ7hPP5e+MNG9m6aCuBgwKUfrmUknNLerXVLPCZy+f3UXJ2Sc909phs/MV+aE9gUcPIk9Cr6kZgozs5CVjZZ8jLwJUicicQUdUmEXkXuF1ExgFHA88faB2drZ1sf2k7Dc86bbX2Le0cs/YYQtNDjL9uPOMWjCN0VMh2wpkBlZxdQsnZJaxYviLRpQwLUVVvFixSDtwI7ARuUdWOPvNvA04HrlHVV9zHTgR+APwJuLnvc9wx84B5AOFwuHrJkiW9B3QCWcBa4BpgDzAaOBY4ATgZCA3Xv3LfotEoodAIvNAQWG1Dk8y1AcyZM6dGVY8ZdKCqevoFTATu7fPY1cAEYBTwEM5n+EOBK9z5n8H5TD/gsquqqrRzT6duf3m7rr9+vb4+7XX94AcfqKrqnh17tHZhrTa+1Kiduzt1pC1fvnzEXzNeVtvQJHNtqqrAao0jk17tyJsB1KvqZqAFOFJEAu6bTDtwkvtGoCKyGJiJs35e4y7id8DXB3ud9jp4JfwKHU0dSLZQcGoBwSnOcen+PD+VP07Oa+Ibk0he7cirAy4RkQactfhNOHvzBbgLeAy41d3JN9F9bBRws4hMBYqBBwd7kfZOJfj5QiZ8oYyiM4rsnGpj4uDVjrwIcI87ucj9/lrM/GXAsj5Pa8Z5Y4jbtkLl1a/lUH3aoDf1MMa4UvqQg1FZsLRmc/d+AmNMHFI69KGA8EHDTmo+3J7oUoxJGSkd+ly/MDqQxdKazYkuxZiUkdKhF4Gzp4/l2bc+oq29M9HlGJMSUjr0AHOrK4ju7uDFd7YmuhRjUkLKh/74ScVUFAVtE9+YOKV86H0+4cIZFfzfhm3UNbUluhxjkl7Khx7gwhkVqMJTb9ja3pjBpEXoJ5SM5vhJxdazNyYOaRF6cHboWc/emMGlTejPOXys9eyNiUPahD43x289e2PikDahB+vZGxOPtAq99eyNGVxahd569sYMLq1CD9azN2YwaRd669kbM7C0Cz1Yz96YgaRl6K1nb8y+pWXorWdvzL6lZejBevbG7Evaht569sb0L21Dbz17Y/qXtqEH69kb05+0Dr317I35pLQOPVjP3pi+ki70IuITkWNE5MjhWJ717I3pzZPQi0ieiNwsIheLyMMiUtln/qEislBErhCR74qIuI+XA9cDH6vqmv6Wvb+sZ29Mb16t6cPASlVdDDwKzO8zfz7wuKo+6I4tcIN/C3Cfqn44nMVYz96YvTwJvapuVNVV7uQkYGWfIS8DV4pIIRBR1SZgNrAHuFxEvi8iZcNVj/XsjdlLvNqr7W6q3wjsBG5R1Y4+828DTgeuUdVXRGQhsF1VF7mf529U1Yv6We48YB5AOByuXrJkSVz1PFXbzjMb9nD3rCAlQe93ZUSjUUKhkOevMxRW29Akc20Ac+bMqVHVYwYdqKqefgETgXv7PHY1MAEYBTwEVAPfAi535wvwx8GWXVVVpfH6cFur
TrzhWX3gpXVxP+dALF++fEReZyistqFJ5tpUVYHVGkcmvdqRN0NEKtzJFuBIEQmISMB97CTg76q6C1gMzARWATPc+aVAZDhrsp69MQ6/R8utAy4RkQactfhNwAKcNfhdwGPArSKyCWdL4C5VbRaRTSLyj8BY4NvDXdTc6gquW/oWNR9u55iDi4d78cakhLhC7wbxdSAL+FfgBVVdtK/xqhoB7nEnu8e9FjN/GbCsn+fdHU89Q3XO4WP5t2feYWnNZgu9yVjxbt43q+pfgRuAr+F8Fk851rM3Jv7QF4jICcAmVY0CDR7W5Cnr2ZtMF2/o1wBnAT8SkaNwPnOnJOvZm0wXV+hV9Q1V/a6qNgOtwIPeluUdn0+YW23n2ZvMFVfo3ePjp4nIPwHfAL7jbVnesvPsTSaLd/P+LeBd4CxVvRpY711J3htfPJoTPmU9e5OZ4g39JOABYImIhIA870oaGXOrx9t59iYjxRv6e3EOoFkClAN/9q6kkXH29HI7z95kpHhDr8AsEfkxziG0f/KupJGRm+PnnMOtZ28yT7yhvxbnGPqHgWbgOs8qGkHWszeZKN7Q/11Vn1TVd91DaOu8LGqkHHdwMeOLrWdvMku8oe97EnHucBeSCHZtfJOJ4g39RyKySERuFZH/Bj7ysqiRZD17k2niPSLvOeBfgKeBhZ5WNMKsZ28yzT5PrRWRH+Fc1KJvEgSnbfdrD+saUXOrx3PtE2vsPHuTEQY6n/55Vb2hvxkiMtOjehLi7Onl3PL023aevckI+9y8V9WXB5j3qjflJIb17E0mSbo73CSK9exNprDQu6xnbzKFhd5lPXuTKSz0MaxnbzKBhT6G9exNJrDQ92Hn2Zt0Z6Hvw86zN+nOQt+H9exNurPQ98N69iadeXUDyzwRuVlELhaRh0Wkss/8Q0VkoYhc4V5pV2Lm5YvIb7yoK17WszfpzKs1fRhYqaqLgUeB+X3mzwceV9UH3bEFAG74LwTqPaorLtazN+nMk9Cr6kZVXeVOTgJW9hnyMnCliBQCEVVtch8/D3gB6PKirv1hPXuTrry6VTUiUg7cCOwEbomdp6pPiUg1TsCvccdPA5pUdWvM1n5/y50HzAMIh8OsWLHCk/oBphb7eOyPtRwmmxmopv5Eo1FPazsQVtvQJHNt+0VVPf3Cuf/8vX0euxqYgHP324dw7mF/Pc6bxI1Ajft99EDLrqqqUi89sfrvOvGGZ/XP7zfs93OXL18+/AUNE6ttaJK5NlVVYLXGkUmvduTNEJEKd7IFOFJEAiIScB87Cedim7uAxcBMVb1TVe9Q1TuAte7PO72oL17WszfpyKsdeXXAl0TkcuB7wE3AApxLbgE8BtwqIl8HznCnk4717E068uQzvapGgHvcyUXu99di5i8Dlg3w/Mu9qGso5lZXsLRmMy++s5Xzjx6X6HKMOWB2cM4grGdv0o2FfhDWszfpxkIfB+vZm3RioY+DnWdv0omFPk52nr1JFxb6OFnP3qQLC32crGdv0oWFfj/YefYmHVjo94P17E06sNDvB+vZm3Rgod9P1rM3qc5Cv5+sZ29SnYV+CKxnb1KZhX4IrGdvUpmFfgisZ29SmYV+iKxnb1KVhX6IrGdvUpWFfoisZ29SlYX+AFjP3qQiC/0BsJ69SUUW+gNkPXuTaiz0B8h69ibVWOgPkPXsTaqx0A8D69mbVGKhHwbWszepxEI/DKxnb1KJZ7eqzjQXzqjgvt/X8tQbm7nytCmJLqfHjl17qI1EWV/fwrpIlNr6KOvrdhKsWdEzJvY23H1vyB17h26JmTvQnbv3tbxey+rz/O5lt7S0kbf2jyhOC7S7E9rzne7pT7ZI947RAZ/T88xB5se+hKLsatvF4kNbmFKW94nXTiWehF5E8nBuWPk+MBu4U1XXx8w/FDgLaAPKcW5yWQB8DYgAZwLfV9V1XtTnhdie/T/PqRzx12/euYfa+hZq66Osi7Swvj5KbSTK1h27esbk+H1UloYY
G/JRXpbvPNjnDztWrz/6fYzrm73Yyd7z4nuOtEcZk5fj/Nz9WM8bhPSa/uT8vW8ePWOk9+N84rnS77J6v1k5U5FIhGAgi1Tn1Zo+DKxU1VUisgmYD1wXM38+zhvBFhH5CU7gDwV+q6prRaQOmAdc61F9nphbPZ5rn1jjac++aWe7u8ZuoTbme33L7p4xwewsKktDnDi5hClleUwpDTGlLERF0WiyfMKKFSuYPXuGZzUeCKe2YxNdRr9WrFhBRdHoRJdxwLy6a+1GYKM7OQlY2WfIy8CVInInEFHVJuDVmPklwF+8qM1LZ08v55an32ZpzWbOKjmwZTW2trMu4qy510f2bppvi+4N9+hAFlNKQ5wyJcyUshBVZSGmlOYxrjCIzzfA9rfJaOLV4aMiUg7cCOwEblHVjj7zbwNOB65R1VdiHh8N3A5cr6q76UNE5uFsBRAOh6uXLFniSf1D9fO1u1m9tYMfHqcUF4QGHKuqtLRDXbSLumgXW1q72OL+3NK+d9yoLDgo5GNcyOd+Fw4K+SgeJfgG+nC9D9FolFBo4NoSxWobujlz5tSo6jGDjfMs9D0vIDIRuEpVF8Y8djWwFKgH7gceVtUacT483QT8VFUbB1v2IYccou+9955HlQ/Naxsb+PLDrzHviBxuuvjTgBPuj6O7WR+J9qy9a+uj1EZa2L5zT89z83L8VJaFqCrNY0pZiMrSEFVleYwtGNVr59iBcjahZw/b8oaT1TZ0IhJX6L3akTcDqFfVzUALcKSIBABUtR04CbhXVVVEFgMzgRqcHXmL4gl8suru2T+/cRctT65lvbtzrSk23KP8VJXlcdb0cipLnc/cVWV5lOXnDGu4jemPVzvy6oBLRKQBqMZZey/A2Sl6F/AYcKu7k28icJeIfAG4CDjZ/cPfrarzParPMz6fcPFxE/nRb97l+bUfUVUW4uzpY3s+b1eVhQjnWbhN4ni1Iy8C3ONOLnK/vxYzfxmwrM/TnnS/Ut78Uz/F+PZNfPaM2RZuk3TsiDwP+HxCKCAWeJOULPTGZBgLvTEZxkJvTIax0BuTYTw/OMdLItICJNfROXuNAbYluoh9sNqGJplrAzhEVQc9BTDVT619L54jkBJBRFZbbfvPahs6EVkdzzjbvDcmw1jojckwqR76hxNdwACstqGx2oYurvpSekeeMWb/pfqaPqmJHYdrklDKhl5ErhSRr4jId0SkONH1dBMRn4hcIiLLgBMTXU8sESkUkWvc+h4TkapE19RNRPJE5GYRuVhEHhaRkb/Q4CBEJF9EfpPoOvoSkSdF5Ofu1zcGG5+SLTsRmQqUqep/iMghwELgOwkuCwBV7QJ+6f7RJtuaPpmvQzjYdRUTyt1quxDnwi/J5hlVXRTv4JQMPXAKzkU3UNX33It2mEGoatJehzCO6yom2nnAC8CsRBfSj8NEZCGQCzygqs0DDU7VzfsxwI6Y6exEFZKK3OsQngI8kehaYolIuYjch7NF8kKi6+kmItOAJlVNyvuWqep1qnov8Bxw92DjUzX09UAh9Gx2tQ883HRzf18Lge/1d+HRRFLVrap6FfAQzhWWksW5wIkiciNwuIjc6L5xJhVVfRNnK2lAqbp5vwq4HOdKO9OAuA4/NECSXoewv+sqJrikHqp6Z/fPIjJVVe9IZD2xRGQOzu/tHREpAgbdGknJ0KvqOhHZIiKXAwfjXFE3KbhrgM/iXBtwlIgEVPUPCS4LgCS/DmF/11U0g3sT+KqITAemEsfvzQ7OMSbDpOpnemPMEFnojckwFnpjMoyF3pgMY6E3JsNY6I3nRGS2e1tykwQs9MZzqroCsPMjkoSF3pgMk5JH5BnviMgCoAnnbsIrgQuALJyzGqcAm1T1IREJ4dyJeANQBryuqq+LSBjnlN11wGTgZ6ra4C77TOBo4HPAWaq6AzPiLPSmh4icBgRV9d/dE3MeBX4BnNp9vraILBWRpcDVwP+o6tvu2KdF5PPA7cBNqlovIuOAcqAB
Z6tyjaq+KCJdwEzgxZH+NxoLvemt+3yB83HW7h+6j3fFjNmAsxVwJPBdAFVVEWkDCoCJqlrvPl6Hc0w9QFfMqan1OBfNMAlgn+lNrL8BLaq6TFV/hRtqel8BaDLOxS7WAlXQc7puEGgGtonIlO7BIpI/AnWb/WAn3JgebnhvBQLAZpxgA1wBLMc5o/FvqvqoG+YFOG8AZcBrqvqqiEwAfgjUAo3A48AE4BHgWzhX6/kB0Al8W1VbRuZfZ7pZ6M2ARGQ2cLKq3pboWszwsM17s08ikotzRd9pIjI50fWY4WFremMyjK3pjckwFnpjMoyF3pgMY6E3JsNY6I3JMBZ6YzLM/wNiCIFwmSgu1gAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_seq2seq(net, train_iter, valid_iter, 0.00001, 5, lang, device, first_train=False)\n",
    "torch.save(encoder, MODEL_ROOT + \"trans_encoder2{}.mdl\".format(num_examples))\n",
    "torch.save(decoder, MODEL_ROOT + \"trans_decoder2{}.mdl\".format(num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 1.572, PPL 4.848, 4.0 min on cuda:0\n",
      "loss 3.388, PPL 30.146 \n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPcAAAC4CAYAAAA/gnrqAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi40LCBodHRwOi8vbWF0cGxvdGxpYi5vcmcv7US4rQAAGXJJREFUeJzt3Xl8VPW9//HXZ7KTPSFsYQ8JIeyCoiAKapVWEaTen9fl3tpyxWpb1KKWRRGxbnj90Va9ba1tuW63Upu6s+gVVBBUcAsRCFtkB1kCBMhC5nP/mEnIniFkcmYmn+fjkQeZ+Z45886Qd+bMmTPnK6qKMSb0uJwOYIzxDyu3MSHKym1MiLJyGxOirNzGhCgrtzEhysptTIgK98dKRSQemApsA8YA81R1c7XxMcBtwDHvVU+p6lf+yGJMWyX+OIhFRHoD6ar6kYhcCExQ1XuqjY8BClW1sMXv3BgD+OmZW1W3Alu9F3sBH9ZapBS4QUTAU/KX/ZHDmLbML8/cACLSCZgOnABmq+qpBpabD+Sq6kf1jE0BpgBER0cP6969u1+yni23243LFZi7Lyxb8wRytoKCggOqmtbkgqrq1y+gBzC/kfErgHuaWk9WVpYGqmXLljkdoUGWrXkCORuwRn3onl/+NInIOSLS1XvxGDBYRCJFJNI7/jPxbpMDGUC+P3IY05b55TU3sAu4SUQOAsOAmXj2ngvwBLAdmCQiCUAssMhPOYxps/y1Q20f8KT34gLvv6urjb955us8+1zGtCX+euZucboFVvVYhSva5fmKctHl1i50ntyZ8qJyCqYUVF1fuUzq1akkjU6ivKicfS/uO31b73JxQ+KI7hZNxYkKTm45WeO2rmgXrnYuXOGBuVPF+EZVcZ904y5zo2WKlivuMjdh8WFEto/EfcrNsc+OoWVatYy7zA0RTic/e0FT7uNRStnwGNIiI3GXuHGXuJEoz8t2d4mb4+uOV13vLnHjLnUT1S2KpNFJlO0uY/MvNtdZZ9/n+tJ5cmeOrzvO5yM+rzOe87ccOlzXgaIPilg3cV3N4ke7yHw6k8RRifAN5P0mr85416ldiekdw/H84xxacqjOePKlyYQnhlO2r4zSnaW4ol1IlFSNRyRHIGFSJ5eT1O395a+AsNgwAEr3llJxrKJGQSRKiB8SD8Dh5YcpP1BeYzyiQwQkeda58+mdlO8rr1Gu2P6xpN+eDsCGyRso/67mePKlyfR6sBcAnw38jFNHT51ef7nS6cedyPxtJrjho9g6b8TQ7Z5uZMzLwH3czRcjv6j7g/7JDw9eKwuach9JVu4fcZgld15ETGRYjbGoTlGc9815Dd62Xd92jNw/sqr0lX8AortHAxCTEUP/V/vX+eMQNyQOgIiOEXT8t441busuceNq531WPwklhSVV12up4i5x0/GGjsT0juHop0fZMm1LnVzD84YTlxjH/oX72Ty17h+fEdtGENMzhu3ztlM4t7DOlsnQFUOJSI5gz5/38N0/vquxVeKKdtHnqT4A7PuffRQtL6rxy48L+v+tPwCFcws5tOhQjfKEJ4Uz7JNhAORfl8/BNw7iLveUGiAmM4YRBSMA+OZfv+HIB0dqZI8bGsfwz4cDsOXuLRSvLa4xnnhRIjzo+X7X07s4WeDZcpJIQSKElHEpVeUu2VbCqaJTSKTgijz981eKHxEPFSARUrVMwsgEACRM6P14byTCc33leOyAWABcsS4GLR5Udb+Vy6zdu7bB36dgETTlbh/tYvuhE/zu/U38alz2Gd1WwoTItMgGxyNSI0j7YcNvG8Zmx5L5u8yG72AYnDvt3AaHO/5bR9ImpdX54xLTJwaA1PGpRPeMrvnHpcRNRHvPtmH8sHjSb0uvM+6K8vyCu0vclB8orzOe+Ywnc/EXxRx842CNX/7KZ10AiRTCEsKIiIyoGg9POf2rkXJFCtE9oqvGJEKqsgH0mNmD
8lvKT49HCuHJp2+f81IO7nJ3jXK5Ylys/GolAOfln9foFsqQ94c0/NgD2c81/vvQ/d6Gj49whbtIuSKl7sDyRlcZFIKm3NHhMGF4V579cCtXD+5Cv84JTkfymSvchSvRBYn1j8f0jCGmZ0yDt0++NJnkS5MbHE//WTrpP0tvcDxjXgYZ8zIaHO8xvQc9pvdocLzzTzo3OAaQcnk95aimXd92jY4H2kuPUBFUe4tm/qAfSTERTM/No8Jtu8+NaUxQlTupXSSzx+fw1Y4iXlhV6HQcYwJaUJUb4OrBXbg4K40nlmxkd9FJp+MYE7CCrtwiwq8nDsCtMPv1dZXHpxtjagm6cgN0S2nHL7+XxXvr97N43V6n4xgTkIKy3AA/HtWT/l0SeOCNfI6cLHc6jjEBJ2jLHR7m4rFJgzhQXMq8xRucjmNMwAnacgMM7JrIT0b14qVPtvNZ4SGn4xgTUIK63AB3fS+L9KQYZuTmUXqqwuk4xgSMoC93bFQ4v544gM37i/njB1ubvoExbUTQlxtgbHYHxg/uwtPvb2bLd8VN38CYNiAkyg0w+6ocoiNczMjNw22HphoTOuVOi49i1pX9+HTbIf6+dofTcYxxXMiUG+D/De/GiF4pPPz2er47Vup0HGMcFVLlFhEemTSQknI3c9/6xuk4xjgqpMoNkJEWx88v6cObX+1m2Yb9TscxxjEhV26An16cQWaHOO57bR3HS+ud6MSYkBeS5Y4Md/HopIHsKjrJ/HcLnI5jjCNCstwAw3umcOOI7vxl5Tbydh5p+gbGhJiQLTfAveOyaR8XxfTcrzlV4XY6jjGtyl9zhcWLyCwRuUFEnhWRPrXGU0XkARG5UURu90cGgMSYCB68uj/5u4/y15WF/robYwKSv56504AP1TPv9vPArbXGpwEvq+pLQLqINHLe4LMzbkAnLuvXkf//bgE7Dp3w190YE3D8Um5V3aqn59vuBXxYa5FBqrrJ+/2XwCh/5ADPe99zJ/THJXDfa3ZaJtN2iL9+2UWkEzAdOAHMVtVT1caWqurl3u8vBwar6hP1rGMKMAUgLS1t2MKFC5ud591vy3lpfRk/HRTF+V1a9nTtxcXFxMXFteg6W4pla55AzjZ27Ni1qjq8yQV9mcT7bL6AHsD8Wte9Ve3764B/b2o9WVlZZzVh+akKt1799Ao9Z+5SPXy89KzWVVsgT9Ru2ZonkLMBa9SH7vlrh9o5ItLVe/EYMFhEIkWkck6fL0Skcg6YocAKf+SoLswlPDZpIEdOlvPIO+v9fXfGOM5fO9R2AdeJyM14pnubCUwF7vCOzwduEJEfAdtVtVXOstCvcwK3XNSbhWt28vGWA61xl8Y4xi9zhanqPuBJ78UF3n9XVxs/BMz2x3035Y5LM3knbw+z/rmORXeMJjoirOkbGROEQvoglvpER4Tx8MSBbDtwnGeW1Z0215hQ0ebKDXBhZnsmnZPO75dvYePeY07HMcYv2mS5Ae67Mof46HBm5H5tp2UyIanNljslNpL7r8rh8+1FvPTpdqfjGNPi2my5Aa4Zms6Ffdozb9EG9h4pcTqOMS2qTZdbRHj4mgGUVbiZ80a+03GMaVFtutwAPVJjufOyLBbn72VJvs0YakJHmy83wH+M7kV2p3geeD2fYyU2Y6gJDVZuICLMxWM/HMS+YyX855KNTscxpkVYub2GdEviRxf05PnV3/L59sNOxzHmrFm5q7n7ir50Sohmxj/yKDtlp2Uywc3KXU1cVDgPTRjAxn3H+NNHNmOoCW5W7louy+nIDwZ24rf/u4ltB447HceYZrNy12PO+P5EhbuYmZtnp2UyQcvKXY8OCdFM/342q7Ye5NW1O52OY0yzWLkbcP253RneI5mH31nPgWKbMdQEHyt3A1wu4dFJAzleeopf24yhJghZuRuR2TGe28b04bUvd/NBwXdOxzHmjFi5m3D7mAx6p8Uy6595nCizGUNN8LByNyE6IoxHrhnIzsMn+e17m5q+
gTEBwsrtg/N7p/Kv53bjuRXbWLfLZgw1wcHK7aMZ3+9HcrtIZuTmUWGnZTJBwMrto8R2ETwwPoe8XUdY8HGh03GMaZJP5RaRH4tIjogMFJFXvJMNtDlXDerM2L5pPLl0IzsP24yhJrD5+sx9RFW/AX4FTAai/RcpcIkID00cgCrMfj3fDk01Ac3XcieKyPl4pv4pBg76MVNA65rcjmmXZ/H+hv28nbfH6TjGNMjXcn8FjAMeF5EhQOfGFhaRJBGZJiI3icgLIpJVa3yMd/P+Oe/X4ObFd8bNI3syMD2ROW98w5ETdlomE5h8mitMVT8HPgcQkePAH5q4ST9gqarmicguPHNs311rmV+pauGZxQ0M4WEuHp00kAnPrOSxxeu5IsXpRMbU5VO5RWQOsBAYCeQAx4H7G1peVVdVu5gKfFlrkVI8s3wCFKrqy75HDgwD0hOZfGEvnv1wKz3Pi2aM04GMqcXXWT6/BjYAc1X1Wu/Uu00SkXbAaODe6td7y7/Ku8x8Edmhqh/Vc/speJ71SUtLY/ny5T7GbR3DIpX2McJf8k7SO2kZES5xOlIdxcXFAfe4VbJs/uVruXsBTwELRSQOiG/qBuJ5Wr4LeFBVG/vM5GLgfKBOuVX1WeBZgL59++qYMWN8jNt6orrt5+a/fsa6inTuuiSr6Ru0suXLlxOIjxtYNn/zdYfafOAJVV0IdAI+8+E2k4EF3rm4EZFIEYn0fv8zb/kBMoCgne5jTN8OnN85jP9avpnN+23GUBM4fH3mVuBiEZmKZ8/5840tLCKTgOuBC70dLgU2AQI8AWwHJolIAhALLGpW+gBxQ3YU64vKmZGbxytTLsAVgJvnpu3xtdx3A1vwbCJnA/cA8xpaWFVzgdxGxt88g4wBLyFKmHVlP+599Wv+9tkObhjR3elIxvi8Wb5DVXNVdYOqvgbs8meoYPQvw7pyfu8UHl20nv1HbcZQ4zxfyx1X63JsSwcJdiLCI9cMpPSUmwfftNMyGef5Wu49IrJAROaKyF8BO+6yHr3T4ph6SR/eztvDe9/sczqOaeN8Kreqvg3cAbyO5+0t04ApF2WQ1TGO2a+vo7jUTstknNPgDjUReRw4B8+e8hpDeN4OC6mdYi0lMtzFo5MGce0fPubJpRt5YHx/pyOZNqqxveXvqOqv6hsQkQv8lCckDOuRzE0jerDg40ImDElnSLckpyOZNqjBzXJV/aCRsVUNjRmPe8b1pUN8FDNy8yivsBlDTeuz0yz5SUJ0BA9ePYD1e47y5xXbnI5j2iArtx+NG9CJy3M68pv3Cvj2oM0YalqXldvPHpzQn3CXi/teW2enZTKtysrtZ50TY7h3XF8+2nSA1760A/tM67Fyt4IbR/RgaPckHnprPYeOlzkdx7QRVu5WEOadMfToyXIefnu903FMG2HlbiXZnRK49eLe/OPznazcfMDpOKYNsHK3ol9ckknP1HbM/GceJeUVTscxIc7K3YoqZwz99uAJfve/NmOo8S8rdysb2ac91w7ryrMfbmX9nqNOxzEhzMrtgFk/6EdCTITNGGr8ysrtgOTYSGZflcOXO4p4cfW3TscxIcrK7ZAJQ7owOrM98xZvYHfRSafjmBBk5XaIiPDwxIFUqNqMocYvrNwO6p7ajrsuy+K99ftYkr/X6TgmxFi5HTb5wl7kdE5g9uv5HC2xGUNNy7FyO6xyxtADxaXMW7zB6TgmhFi5A8DgbkncPLIXL67ezprCQ07HMSHCyh0gpl2eRXpSDDNy8yg7ZadlMmfPL+UWkSQRmSYiN4nICyKSVWs8VUQeEJEbReR2f2QINrFR4Tw0sT+b9hfzxw+2OB3HhAB/PXP3A5aq6ovAX/DOsV3NNOBlVX0JSBeRTD/lCCqXZHfkykGdeer9zWz5rtjpOCbI+aXcqrpKVfO8F1OBL2stMkhVKz858SUwyh85gtED43OIinAxMzfP3vs2Z8XXWT6bRUTaAaOBe2sNRVb7/gjQs4HbT8H7
rJ+Wlsby5ctbPmQLKC4ubtFsP8xwsSD/EA+99B4XdY04q3W1dLaWZNn8TFX98oVnZpJZQEo9Y29V+/464N+bWl9WVpYGqmXLlrXo+ioq3Povv/9YB81ZovuPlpzVulo6W0uybM0DrFEfOujPveWTgQWqeghARCJFpPIZ+wsRyfZ+PxRY4cccQcflEh6ZNICTZRU89JbNGGqaxy+b5SIyCbgeuFBEAEqBTXiezZ8A5gN3isgWYLuqbvVHjmDWp0M8t4/N4DfvbeKac9IZ27eD05FMkPFLuVU1F8htZPwQMNsf9x1KbhuTwZtf7ea+f67j3V9eRLtIv+4iMSHGDmIJYFHhYTw6aRC7ik4y/90Cp+OYIGPlDnDn9Urh+vO68+cV21i364jTcUwQsXIHgenfzyY1LorpuV9zymYMNT6ycgeBxJgI5ozvz7pdR1nwcaHTcUyQsHIHiR8M7MSl2R14cmkBOw6dcDqOCQJW7iAhIsydOAARbMZQ4xMrdxBJT4rh7sv78kHBd7z59R6n45gAZ+UOMj8a2ZPBXROZ+2Y+RSdsxlDTMCt3kPHMGDqIwyfKefQdOy2TaZiVOwjldEngltG9eWXNDlZtOeh0HBOgrNxB6o5LM+me0o5ZNmOoaYCVO0jFRIbx8DUD2HrgOP+1bLPTcUwAsnIHsdGZaUwams7vP9hCwb5jTscxAcbKHeRmXdmPuKhwZuTm4bYZQ001Vu4glxoXxX1X5rD228O8/Ol2p+OYAGLlDgGTzklnVJ9UHl+0gX1HS5yOYwKElTsEVM4YWlbhZs4b+U7HMQHCyh0ieraP5Y7LMlm0bi9LbcZQg5U7pNwyujfZneKZ/Xo+x2zG0DbPyh1CIrwzhu47VsKTS+20TG2dlTvEDO2ezI8u6Ml/ryrki+2HnY5jHGTlDkF3X9GXTgnRzMjN45S9991m2blyQ1BcVDhzJwzglufX8JyEsYFNVWPe88jXUf1qQepcL2ewbN11V1um2vVbCsvZtnJbPeuXeu6n/gBnmsuX5RHYvPsUY2r/IEHGyh2ivpfTkeuGd+OVNTtYvSeAX39vCMwZVWLCYabTIc5Sq5RbRETtvECt7vFrB/G9lINcfPHFVH/0ldMXGvpf8WV5rbG8NnB99Qs117lixUpGjRpV37BP61PqD+PL8k39HJ+sXk2w81u5RcQF3ABci2cKoZW1xnsCzwCV5wt6WVXf91eetircJUSEBeaulbhIITk2sukFHbA5JjAfszPht3Krqht4UUT6UOslUzVPqOpyf2Uwpi1z8jV3BfB9ETkXOA78wfsHwRjTAsTfL4VFZA7wnqo2OE2viNwJHFTVF2pdPwWYApCWljZs4cKF/ozabMXFxcTFxTkdo16WrXkCOdvYsWPXqurwJhf0ZRLvs/kC5gAXNrFMX+CZxpbJyso660nL/SWQJ2q3bM0TyNmANepD91p1s9y7ky1KVU+KyH8AL6hqKZABNPpxpoKCgmIR2dgaOZuhPXDA6RANsGzNE8jZ+vqykD/3lrcDrgSGAdEiEgmkAhcDPwe+Bq4XkVN4yv1oE6vcqL5sijhARNZYtjNn2ZpHRNb4spw/95afAP7u/aru797xT4FP/XX/xrR1wf9mnjGmXsFU7medDtAIy9Y8lq15fMrm97fCjDHOCKZnbmPMGbBytxBp6LOUxjgkKMotIj8XkRtF5H4RSXE6TyURcYnITSLyGjDS6TzViUiSiEzz5ntBRLKczlRJROJFZJaI3CAiz3o/fxBQRCRBRBY7naM2EckVkee8X7c1tmzAf55bRLKBjqr6tIj0Be4C7nc4FuDzh2Oc0g9Yqqp5IrILz2G8dzucqVIa8KGqfiQi24FbgXsczlTFuxX2Q2C/01nq8YaqLvBlwYAvNzAaWAugqhtF5ByH8wQFVV1V7WIq8KVTWWpT1a3AVu/FXsCHDsapz3hgEZ4DrgJNfxG5C4gFnlLVIw0tGAyb5e2Bo9UuRzgVJBh5jxQcTd2DiRwl
Ip1E5Dd4tjAWOZ2nkojkAEWqGpAnf1fVe1R1PvA28J+NLRsM5d4PJEHV5lKZs3GCh/fxugt40HsMf8BQ1b2qeifwRzwn8wgUVwEjRWQ6MFBEpnv/QAYUVf0Cz1ZPg4Jhs/wj4GYgF8gBfDqu1gAwGVigqoecDlKd96XVflXdCRwDBjscqYqqzqv8XkSyVfUxJ/NUJyJj8Txu+SKSDDS6dRHw5VbVAhHZLSI3Az2B3zqb6LT6PhyjAXKqKBGZBFwPXOh9l65UVW91NlWVXcBNInIQz2MX7OcibC1fAD8RkQFANk08bnaEmjEhKhhecxtjmsHKbUyIsnIbE6Ks3MaEKCu3aTEiMkZE5jW9pGkNVm7TYtQzwYQdHhwgrNzGhKiAP4jF+IeITAWKgB54PrhxDRCG50M6mcB2Vf2jiMQBU4EtQEfgE1X9RETS8HzSrADP2Wv/pKoHveu+AhgKXA2MU9WjmFZn5W6DROQSIEZVf+c9/vx54M/ARZUfJxSRV0XkVeCXwP+o6jrvsq+LyAQ8p6Keqar7RSQd6AQcxLM1+JWqLhERN3ABsKS1f0Zj5W6rKg+XnYjn2fpb7/XV52rbgudZfTCeWWNQVRWRk0Ai0ENV93uv34XnkFIAd7VPVO3H89lt4wB7zd02rQeOqeprqvoPvOWl5gknMvB85joPyIKqT5nFAEeAAyKSWbmwiCS0Qm5zBuzY8jbIW9K5QCSwE0+BAX4KLMPzAZ31qvq8t7RT8RS9I7BaVVeJSHfgEWATcAh4BegO/DfwCzwnh3gYz2yuM1T1WOv8dKaSldsAnveo8UzY+Guns5iWYZvlBhGJxXOCxxwRyXA6j2kZ9sxtTIiyZ25jQpSV25gQZeU2JkRZuY0JUVZuY0KUlduYEPV/SUlc2m0WG5oAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPQAAAC4CAYAAADUtcHpAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi40LCBodHRwOi8vbWF0cGxvdGxpYi5vcmcv7US4rQAAGXRJREFUeJzt3Xl4VIW5x/Hvm5kkkD1ACGvYCYEEFFAropK64ApCa6uVe7V1vWq11g1wKS6ItbV2ub22XO11aeu9tA2IWhUVUrVQFUQNiEGWsAlE9kwSQpJ57x8zCWFLhpCZc2byfp5nHpKz/hLynnPmzDnvEVXFGBMb4pwOYIxpO1bQxsQQK2hjYogVtDExxAramBhiBW1MDLGCNiaGWEEbE0OsoI2JIVbQxsQQr9MBmpORkaEDBw50OsZRVVZWkpyc7HSMo7JsrePmbMuWLduhqlktTefqgs7Ozmbp0qVOxziq4uJixo0b53SMo7JsrePmbCKyIZTp7JDbmBhiBW1MDLGCNiaGuPo9tK6DpScvxZvpxZvpJb5TPL3v6k1SbhLV66up+LDikHHeTC/eDC8SJ05HN8YRri7oqnhFusXjr6in6osq6nbX0f2G7gDsKd5D6Q9Kj5hn9CejSRmRwvY/bWfTk5sOLfZMLzn35hDfKZ6q1VXs37D/kHHedNsYxCr/AT8SL4gINVtrqNlUQ31FPfW+wEvrFXKcTnniwlLQIpIK3AasB8YBTwBpwDlAJfClqr7V0nJ2pCsvXO3nV1eMPGJc1rezSDs1jdrdtdTtrmt8JeYkAuBJ9ZDQI4G63XVUrqykbncdtbtq6X1nbwC2v7idDY8eduJQYOyesXjTvGz+9WZ2vrLzYLFneonPjKf3Pb0REdgEFR8fPELwptnGoK346/yBYmtScPW+elJHp+JN8+Ir8bF7wW7qffXUVdQ1jh/45EASshPY9vw2Nj6x8eC8FfVorTKmfAwJWQls+e0WNs7ceMg6PSkeeMWhH7gNhWsPnQW8q6rvichG4EYC278rVdUvIi+JyEJVrW9uIemJwsuffMWkk3syLrfrocFTvXiHHTt+lwld6DKhyyHDmrZb6nFTDzLPz2zcENTurqVuVx2eVE9wYqj31bN/0/7GaSRByLk3uBl/Hpa9s+zgwuOgY/+OnPblaQCUPVJG5YrKxg2BN9NLYk4i2VdkA1C9rhriID4zHk+aJ7CRiEL+On9j4XjTvXhTvdTuqmXv+3sPKcZ6Xz1Z3w58jFqxvIKyh8oOHV9RT96f8sgYm8HXf/2aVVeuOmJdI/81krTT0qj4sIK1d60FIC4pDk+KB0+Kh7p9dSRkJ+DN9JI8NLlxuCfFgyfVQ1xi4JRR9lXZpI9Jx5PqOWSaJaVLIveLC5OwFLSqrgPWBb/tB7wPXK2q/uCwXUBvoKy55aQnCr2ykrl/3goW3HEWSQknFrdp0ST2TCSxZ+Ixp+11ey963d6r8XtVxV/tPzjB92DYrcMCxb4rsEEQz8HlH9h2AN9yX+PGgnpIzk9uLOhVU1axb8m+wMRx4M30knF2Bvl/ywdg3bR11O2tO+QtQ8dBHck4MwOAmq01jX+IoW4MVBURQeuVqi+rjiio5KHJpIxIoXZPLRtnbTxiD9n9hu50/XZXqkqrWD52OfW+evz7D/5Ocp/NpfsPulP9ZTUrJq44Yv1JQ5KgK/hr/Oxftx9PqgdvupfEnol4UgJfA6SOTGXAUwMCw1K9jT9nUl4SAF2v6krWd7LwJHkO+Z03ONrGvKnkvGSS845yAcmR7+CijoSrSaCIdAOmAlXA08BMVf334LjHgLmq+tFR5rsBuAEgKytr1CO/e4lZH+7ngr7xXDEkISxZW8Pn85GSkhLaxApUA/uBTsFhy4BtgA+oCL6ygCnB8XcQ2CT6gIaaGQPMDH49GdgNeIDU4Ots4FrwVfhImZYSWGc1gf+BamACcEvw64uO
kvPfgB8Ae4DvAh2BpOC/HYFJBN407QGeazK8YZoCAsdh1cDGJuMbXp7j/L1FmJuzFRYWLlPV0S1NF7aTYqq6DfiRiPQBfkzgT65BJrD9GPPNBmYD5Obm6o2Tz6GMz5izdDO3Xnoa+T3TwxX5uJzwVUUtzbo88I+qUl9RT93uOhDokNMBgK2/2Ert17WHvGVIPyWdXuN6UVxcTKe+nYiLj2s83PSkeEg/K50u47qgqpT/ufyIQ86E7ATiM+MDK65pId9lrfux3Xw1lpuzhSpcJ8VGAuWqupnAvmcE8LWIeILvmzOBLaEub+oFebz1eTnTikqYd8sZeNrRyScRwZsWOOnWVPdrujc73/BXhze7zOwrs9skn3GXcO2htwBTRGQnMAqYTmCbf6+I+IBnWjoh1lR6UjwzJgzl1j8v57nFZVw7tl94UhsT5cJ1Umw78GTw2+eajFre2mVeXNCdoiFbeHJBKeOHZdMrM+lEIhoTk6Lm0k8R4eGJwwB48OWV2BM/jDlS1BQ0QK/MJO48P5eFX5TzWslWp+MY4zpRVdAA14zpy/Be6cyY/zl7q2qdjmOMq0RdQXvihMcmFbC76gCPv3Hk1UTGtGdRV9AA+T3TuW5sP176cBMfrNvpdBxjXCMqCxrg9nMH0btTR6bNLaGmLuRPwIyJaVFb0EkJXh69rIB1X1fyX4vWOh3HGFeI2oIGOHtwFhNP6sF/Fa9hTXmF03GMcVxUFzTAA5cMJSnBy7SiEvx++2zatG9RX9BdUhK57+I8Pirbzf9+tMnpOMY4KuoLGuDyUb34Rv9OzHp9FeX79jsdxxjHxERBiwQ+m66p8/PQq587HccYx8REQQP0z0rhtm8O5LXPtvLOqqPeam1MzIuZgga44awBDM5O4YF5K6isqXM6jjERF1MFneCNY9bk4Wzdt58nF6x2Oo4xEReujiUZwLUE2gyNBx4h0NFqSJPJblXVNj+DNapPJlNO68Nzi9cz8aQejOid0darMMa1wrWHzgMWqOofgT8QaPq3Q1Wva/IK2+nouy/IJSs1kalFJdTW+1uewZgYEZaCVtUlqloS/LYz8AngFZFbROQRETklHOttkNYhnocm5LNq6z7+8P76cK7KGFcJ66NwRCQJOBO4R1VrgsMSgFdE5HJV3ReudV+Q343zh2bz1NuruTC/OzmdrWWRiX3h7MstBJoDPq2quw4bNwsoCqUv95w5c1qdYdd+P9Pfq2Zghoc7Rye26dMp3NzD2bK1jpuzhdqXG1UNywu4DugZ/LoDcF2TcS8AnVtaxuDBg/VEPffP9drn3ld17sebT3hZTS1atKhNl9eWLFvruDkbsFRDqLtwneWeDFwJjA3uFWuA5SIykcCzFV5W1Yh0JpjyjT7MXb6Fh1/9nLMHZ5GZ7J6nbxjT1sLVxrcIKArHso+XJ06YNbmAS3/zPjP/voqfXz7C6UjGhE1MXVhyLHnd07jhrP78ddlmFq/Z4XQcY8KmXRQ0wG3nDKJP5ySmzy1hf621LDKxqd0UdId4D49NKqBsZxX/uXCN03GMCYt2U9AAZwzswrdG9uJ3/1hL6TZrWWRiT7sqaID7Ls4jrWM8U4s+s5ZFJua0u4LulJzAA5fksXzjHv70wQan4xjTptpdQQNcdlJPzhzUhZ++Ucq2vdayyMSOdlnQIsLMywqo8/v5yfwVTscxps20y4IGyOmcxI/OHcybK7fzxoptTscxpk2024IGuHZsP/K6p/GT+Suo2G9PsjTRr10XdLwnjscnF1BeUcPP3ix1Oo4xJ6xdFzTAiN4ZXH16X1781waWbdjtdBxjTki7L2iAu8bn0i2tA9OtZZGJclbQQEqil0cm5lO6vYLZ765zOo4xrWYFHXTu0GwuKujGr975kvU7Kp2OY0yrWEE3MePSYSR647hvbklDZxVjokpYClpEMkTkThGZIiIvishgERkpIneLyM0icl441nuiuqZ1YOqFQ1i8did/+3iL03GMOW7h6vrZ0Je7
RES2EGj61xu4UlX9IvKSiCxUVdfdmHzlKTnM/XgLj772OYW5WXROSXQ6kjEhi1Rf7hIgXlUbTiHvIlDgrhMXbFlUWVPHo6+tcjqOMcclUn25fwGc02TUXiALKDvKPE3b+FJcXBzOiMd0UV8vc5dvYYBnB/ldjvw1+Xw+x7K1xLK1jpuzhSyU1qCteQEC3Ad0AhKAuU3GPQ3ktLSMtmjj21rVB+q08OeLdOxP39Gqmrojxru55atlax03ZyPENr7hPMt9LfCcqu5S1QPAARHxBMdlAq4+69TQsmjTrmp++Y49ydJEh0j25X4cuFdEfMAz6sITYof7Rv/OXHFKb555bz0TRvRgWI90pyMZ06xI9+VeHo71hdO0C/N4e1U504pKmHvzGXji2u5xOsa0tVYdcovIqW0dxK3Sk+J58NKhfLZ5L88vLnM6jjHNOuYeWkRuBiYBh18yJUA3oCCMuVzl0uHdKfp4Mz9fUMr4/G70zOjodCRjjqq5PbRPVc9T1fMPe50HPBipgG4gIjwyMR9VeHDeCrss1LjWMQtaVV9o+r2IZIrISSKSpqpzwx/NXXp3SuLO8wfzzhflvG4ti4xLhfQeOnixx2zgCuBZEZkS1lQudc2YvuT3TOMn81dSWWt7aeM+oZ4Uy1HVy1V1qqpeDgwNZyi38nrieHzycHb6avjL6gNOxzHmCKEW9OFXVqwHEJF+bRvH/fJ7pnPt2H4Ub6rjo7JdTscx5hChfg49WkRyAD+BjcBQEckisKdud4ffd5w3mKKPyphWVMJrt40l0etpeSZjIiDUPXQ98C7wfvDf3wP/BFaGKZerJSV4uXpYAmvKffyu2FoWGfcIdQ/9P8B0AjdabAMeUNUNIvJR2JK53PAsLxNGdOa3i9Zw8fDuDOya4nQkY0LeQ98E/FBVzwfuBu4BUNWqcAWLBg9cMpSOCR6mF5XYkyyNK4Ra0J+q6tcAqrodKAUQkXbdziMrNZH7Lsrjw7JdzFm6yek4xoRc0PtEpJeI9BCRnkCdiPQAvhvGbFHh8tG9+Eb/Tjz291WUV9iTLI2zQi3om4GHgZnAo8Apwa+vDlOuqCEiPDapgP11fh5+5XOn45h2LtSTYneo6tLDB4rIoDbOE5X6Z6Vwa+FAfvHWar41spzCIV2djmTaqZD20Ecr5uDwL0OZX4JdDmLZTWcPYFDXFO6ft4LKmjqn45h2KmwtiEQkLtiXex4wJjjsGhF5psmrZ7jWH2kJ3jhmTS5gy55qnnrLWhYZZ4St66cGWvb+UUQGEriHumH4deFap9NG9+3EVafl8Id/rmfiST0p6GUti0xkRfpROFXBJ2fMEJHxEV53RNxzwRC6pCQytegz6uxJlibCJNw364vIDOBtVX2/ybA44M/AdFVdd9j0Tftyj5ozZ05Y87WWz+cjJeXoV4d9tK2O335Sw3dzE7iwX3yEkzWfzWmWrXUKCwuXqeroFicMpdfvibyAGcDYowy/Ebi8uXmd7MvdkuZ6OPv9fr32uY90yP2v68adlZELFeTm/tKWrXVwQV/uQ0jALU0G9Qdi8lkzIsLDE4cRJ3C/tSwyERTOs9xJInI5MAq4BCgEfCJyiYjcCKxR1RXhWr/TemR05O7xufxj9dfM//Qrp+OYdiKcZ7mrgL8EX+3Sv53el7mffMXDr3zO2YOzyEhKcDqSiXH2wPcw8sQJj08uYG91LY/9PSbfXRiXsYIOs7zuaVx/Vn/mLN3M4rU7nI5jYpwVdATcfs4g+nRO4r65K9hf6/pHepkoZgUdAR3iPcy8rID1Oyr57aI1TscxMcwKOkLGDurC5JN78nTxWlZvr3A6jolRVtARdN/FeaR28DLNWhaZMLGCjqDOKYncf/FQlm3YzZ8/3Oh0HBODrKAjbPLInpwxsDM/ff0Ltu+zlkWmbVlBR5iIMPOyAg7U+5kxv122NTdhZAXtgL5dkrn93EG8vmIbC1bakyxN27GCdsj1Z/ZnSLdUHnx5JRX7a52OY2KE
FbRD4j2BlkXbK/bz5AJrWWTahhW0g07OyeTq0/vy/JIylm/c7XQcEwOsoB121/hcuqV1YFpRCbXWssicICtoh6Ukenl4Yj5fbKvgv9+zJ1maExORgm4PfblPxHlDs7kwvxu/evtLynZUOh3HRLFI9+XuJyL3i8j3ReSKcK07Gs2YMIwETxz3zSuxlkWm1cJW0KrqV9U/Ap9wsC/3/cCvVfV/gHNEJCNc64822WkduOfCIfxzzU6KPt7idBwTpSL9HjpbVfcFv14NnBTh9bvaVafmMKpPJo++9jm7Kg84HcdEobD1FDuGpk219gJZh09wWF9uiouLI5PsOPl8vrBkm9zLz4Mba/nhs4u4fnjrHr8drmxtwbKFV6QLuuluJxP44vAJVHU2MBsgNzdXx40bF5lkx6m4uJhwZduWWMpvFq7hPy4cxdhBXY57/nBmO1GWLbwifci9TUQaHvg0iMD7a3OYWwoH0q9LMtPnllB9wFoWmdBFrC+3iHwTeAS4Q0SuJ/B4nH3NLqSd6hDv4bFJBWzcVcWvF4b0xF5jAGf6cs8I1zpjyekDOvOd0b2Y/e46JozoQV73NKcjmShgV4q52PSL8sjoGM/UohLqrWWRCYEVtItlJCXw4KVD+XTTHl5cUuZ0HBMFrKBdbsKIHpw9OIufvVnKV3uqnY5jXM4K2uVEhEcvy8ev8ODL9iRL0zwr6CjQu1MSPz5vMG+vKueNFdayyBybFXSU+P4ZfRnWI42fzF/J3mprWWSOzgo6Sng9cTw+eTg7fDU88cYRF9gZA1hBR5WCXul8/4x+/OmDjSwt2+V0HONCVtBR5sfnDaZnRkemFZVwoM5aFplDWUFHmeREL49els+X5T5+/4+1TscxLmMFHYUKh3TlkuHd+c3CNaz92ud0HOMiVtBR6sFLh9IhPo7pRdayyBxkBR2luqZ2YPpFeXywfhd/WbrZ6TjGJaygo9h3Rvfm1H6dmPn3VXxdUeN0HOMCVtBRLC5OeGxSAdUH6nnk1c+djmNcwAo6yg3smsIthQOZ/+lXLCotdzqOcVjEC1pEikTkmeDrPyK9/lh007j+DOyawv1zV1B1oM7pOMZBTuyh56vqdcHX0w6sP+Ykej3MmlzAlj3VPPWWPcmyPYt010+AYSJyB5AM/EZV9zqQIeac0rcT3zsth2ffX8/eQfFsSixrfoZmnk7U0nOLWnqwkTSzhNWbatn64cYW5m/9+ptbd0sL37GznnEtrNvtxKnPMEXkZOBmVb3+sOFN+3KPmjNnjhPxWuTz+UhJSXE6xiEqa5WHllRTXmWfS7dGQSflzlPd9X/aoLCwcJmqjm5pOscKGkBE3lbVc481Pjc3V0tLSyMZKWRu7eF8oM7PGwv/wRljxhwxrqX/6eb+FLS5uY9j1OLFixkTzNbSn15z62w+azPjmpnx4w8/YML4wuZDOUREQiroiB5yi0ghUK6qK0UkE7C79dtYgjeOtAShc0rrnroRbpkd4shO6+B0jKNakxj9D0mN9Hvo5cAPRCQfGAJMj/D6jYlpES1oVd0D/CKS6zSmPbELS4yJIVbQxsQQR89yt0REKgB3nuaGLsAOp0Mcg2VrHTdny1XV1JYmcuLCkuNRGsqpeieIyFLLdvwsW+uIyNJQprNDbmNiiBW0MTHE7QU92+kAzbBsrWPZWiekbK4+KWaMOT5u30O7mkhL9x0ZE1muLWgRuVVErhKRB0Skk9N5GohInIhMEZF5wJF3QDhIRDJE5M5gvhdFZLDTmRqISKqI3Cci3xOR2SIy0OlMhxORNBF5w+kchzuepiCu/NhKRIYA2ar6nyKSC9wBPOBwLABU1Q/8MfgH6bY9dB6wQFVLRGQLgdtQ73I4U4Ms4F1VfU9ENgI3Anc7nKlR8GjrW4Ab+zjNV9XnQpnQlQUNnAksA1DVUhEZ6XCeqKCqS5p82xn4xKksh1PVdcC64Lf9gHcdjHM0lwKvA2c7HeQoQm4K4tZD7i7AvibfxzsVJBqJ
SBKBjeJfnM7SlIh0E5FfEjiSeN3pPA1EZCiwR1VdeTuvqt6tqk8BrwE/b25atxZ0OZABjYdCB5yNEz2Cv687gIdU1VXNulV1m6r+CPg98DOn8zRxCTBGRKYCBSIyNbhRdBVVXU7g6OaY3HrI/R5wDVAEDAVCuuzNAHAt8Jyquup5s8G3TeWquhmoAEY4HKmRqj7R8LWIDFHVx53M09TxNgVxZUGr6moR+UpErgH6Ar9yNtFBwS33xcAooIOIJKjqQodjASAik4ErgbHBT9RqVPVGZ1M12gJMEZGdBH531twiNMfVFMQuLDEmhrj1PbQxphWsoI2JIVbQxsQQK2hjYogVtDExxAranBARGSciT7Q8pYkEK2hzQlS1GLBr7V3CCtqYGOLKK8VMeIjIbcAeoA+Bu50mAR4Cd7YNAjaq6u9FJAW4DVgLZAMfqOoHIpJF4JbM1cAA4L9VdWdw2eOBk4EJwAWqug8TcVbQ7YSIfBPoqKq/Dt7A8QLwLHBWw722IvJXEfkr8GPgJVVdEZz2ZRGZCMwCpqtquYj0BLoBOwkc6X2qqm+KiB84HXgz0j+jsYJuTxquPb+MwF55Q3C4v8k0awnsvUcAMwBUVUWkGkgH+qhqeXD4FgLXZwP4m9x6WE6gmYFxgL2Hbj9WARWqOk9V/0awYDm068oAAk0ISoDB0Hg7ZkdgL7BDRAY1TCwiaRHIbY6D3ZzRTgQL82EgAdhMoGgBbgIWEbirbZWqvhAs1NsIFHc28C9VXSIiOcBjwJfALuD/gBzgeeCHBDqkzATqgWmqWhGZn840sIJux0RkHDBWVR91OotpG3bI3U6JSDKBrqVDRWSA03lM27A9tDExxPbQxsQQK2hjYogVtDExxAramBhiBW1MDLGCNiaG/D/3lO2Tdx60MQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 252x180 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Continue training the existing `net` (first_train=False; presumably lr=0.0001\n",
    "# for 5 epochs -- confirm against train_seq2seq's signature), then checkpoint\n",
    "# the encoder and decoder as whole pickled modules.\n",
    "train_seq2seq(net, train_iter, valid_iter, 0.0001, 5, lang, device, first_train=False)\n",
    "for _module, _template in ((encoder, \"trans_encoder2{}.mdl\"),\n",
    "                           (decoder, \"trans_decoder2{}.mdl\")):\n",
    "    torch.save(_module, MODEL_ROOT + _template.format(num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Another training round on the existing `net` (presumably lr=0.002 for 10 epochs).\n",
    "# NOTE(review): this lr is 20x the 0.0001 used in the fine-tuning cell above -- confirm intended.\n",
    "# The checkpoints overwrite the same trans_encoder2/trans_decoder2 files as before.\n",
    "train_seq2seq(net, train_iter, valid_iter, 0.002, 10, lang, device, first_train=False)\n",
    "for _module, _template in ((encoder, \"trans_encoder2{}.mdl\"),\n",
    "                           (decoder, \"trans_decoder2{}.mdl\")):\n",
    "    torch.save(_module, MODEL_ROOT + _template.format(num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same as the cell above: continue training `net` (presumably lr=0.002, 10 epochs)\n",
    "# and overwrite the trans_encoder2/trans_decoder2 checkpoint files.\n",
    "# NOTE(review): duplicate cell -- consider a loop over rounds instead of copies.\n",
    "train_seq2seq(net, train_iter, valid_iter, 0.002, 10, lang, device, first_train=False)\n",
    "for _module, _template in ((encoder, \"trans_encoder2{}.mdl\"),\n",
    "                           (decoder, \"trans_decoder2{}.mdl\")):\n",
    "    torch.save(_module, MODEL_ROOT + _template.format(num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def truncate_pad(line, num_steps, padding_token):\n",
    "    \"\"\"Truncate or right-pad a token sequence to length `num_steps`.\n",
    "\n",
    "    Sequences longer than `num_steps` are cut to their first `num_steps`\n",
    "    items; shorter ones are extended with copies of `padding_token`.\n",
    "    \"\"\"\n",
    "    shortfall = num_steps - len(line)\n",
    "    if shortfall < 0:\n",
    "        return line[:num_steps]  # too long: truncate\n",
    "    return line + [padding_token] * shortfall  # too short (or exact): pad\n",
    "\n",
    "def predict_seq2seq(net, src_sentence, lang, num_steps,\n",
    "                    device, save_attention_weights=False,\n",
    "                    pad_idx=1, bos_idx=2, eos_idx=3):\n",
    "    \"\"\"Greedily decode a reply to `src_sentence` with an encoder-decoder model.\n",
    "\n",
    "    Args:\n",
    "        net: model exposing `encoder`, `decoder` (with `init_state`) and, when\n",
    "            `save_attention_weights` is true, `decoder.attention_weights`.\n",
    "        src_sentence: raw input text; it is lower-cased and split on whitespace.\n",
    "        lang: vocabulary providing `word2idx(word)` and `index2word[idx]`.\n",
    "        num_steps: encoder input length (truncate/pad target) and maximum\n",
    "            number of decoding steps.\n",
    "        device: torch device on which the input tensors are created.\n",
    "        save_attention_weights: if true, record `net.decoder.attention_weights`\n",
    "            after every decoding step.\n",
    "        pad_idx, bos_idx, eos_idx: padding, begin- and end-of-sequence token\n",
    "            ids (defaults are the values this function used to hard-code).\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(reply, attention_weight_seq)`: the predicted words joined by\n",
    "        single spaces (EOS excluded) and the recorded attention weights.\n",
    "\n",
    "    Note: leaves `net` in eval mode.\n",
    "    \"\"\"\n",
    "    def voc(line):\n",
    "        return [lang.word2idx(i) for i in line]\n",
    "\n",
    "    net.eval()\n",
    "    src_tokens = voc(src_sentence.lower().split()) + [eos_idx]\n",
    "    # The sequence is truncated to num_steps below, so clamp the valid length\n",
    "    # to match instead of reporting positions that no longer exist.\n",
    "    enc_valid_len = torch.tensor([min(len(src_tokens), num_steps)], device=device)\n",
    "    src_tokens = truncate_pad(src_tokens, num_steps, pad_idx)\n",
    "    output_seq, attention_weight_seq = [], []\n",
    "    # Inference only: disable autograd so no graph is built across the\n",
    "    # decoding steps.\n",
    "    with torch.no_grad():\n",
    "        # Add the batch axis -> shape (1, num_steps).\n",
    "        enc_X = torch.unsqueeze(\n",
    "            torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)\n",
    "        enc_outputs = net.encoder(enc_X, enc_valid_len)\n",
    "        dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)\n",
    "        # Decoding starts from a single BOS token -> shape (1, 1).\n",
    "        dec_X = torch.unsqueeze(\n",
    "            torch.tensor([bos_idx], dtype=torch.long, device=device), dim=0)\n",
    "        for _ in range(num_steps):\n",
    "            Y, dec_state = net.decoder(dec_X, dec_state)\n",
    "            # Greedy step: the argmax over dim 2 (presumably the vocabulary\n",
    "            # axis) is both the prediction and the next decoder input.\n",
    "            dec_X = Y.argmax(dim=2)\n",
    "            pred = dec_X.squeeze(dim=0).type(torch.int32).item()\n",
    "            if save_attention_weights:\n",
    "                attention_weight_seq.append(net.decoder.attention_weights)\n",
    "            # Generation is complete once the end-of-sequence token appears.\n",
    "            if pred == eos_idx:\n",
    "                break\n",
    "            output_seq.append(pred)\n",
    "    return ' '.join([lang.index2word[i] for i in output_seq]), attention_weight_seq\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " hello\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i am sorry , sir .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " hi.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i am sorry , sir .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " how are you ?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i think i can ’ t afford it .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what's matter ?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i think i can ’ t afford it .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " how much ?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i think i can ’ t afford it .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what ’ s your ideal job ?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i ’ m not sure . i ’ m sure i ’ m sure i ’ m .\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      " what's your ideal boss ?\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "i think i can ’ t afford it .\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "Interrupted by user",
     "output_type": "error",
     "traceback": [
      "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m                         Traceback (most recent call last)",
      "\u001B[0;32m<ipython-input-19-3e2491837299>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m      1\u001B[0m \u001B[0;32mwhile\u001B[0m \u001B[0;36m1\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m     \u001B[0msentence\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0minput\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m      3\u001B[0m     \u001B[0;32mif\u001B[0m \u001B[0msentence\u001B[0m \u001B[0;34m==\u001B[0m \u001B[0;34m'q'\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      4\u001B[0m         \u001B[0;32mbreak\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m      5\u001B[0m     \u001B[0mout\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mweight\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mpredict_seq2seq\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mnet\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0msentence\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mlang\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mnum_steps\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdevice\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/.local/lib/python3.6/site-packages/ipykernel/kernelbase.py\u001B[0m in \u001B[0;36mraw_input\u001B[0;34m(self, prompt)\u001B[0m\n\u001B[1;32m    861\u001B[0m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_parent_ident\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    862\u001B[0m             \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0m_parent_header\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 863\u001B[0;31m             \u001B[0mpassword\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mFalse\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    864\u001B[0m         )\n\u001B[1;32m    865\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;32m~/.local/lib/python3.6/site-packages/ipykernel/kernelbase.py\u001B[0m in \u001B[0;36m_input_request\u001B[0;34m(self, prompt, ident, parent, password)\u001B[0m\n\u001B[1;32m    902\u001B[0m             \u001B[0;32mexcept\u001B[0m \u001B[0mKeyboardInterrupt\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    903\u001B[0m                 \u001B[0;31m# re-raise KeyboardInterrupt, to truncate traceback\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m--> 904\u001B[0;31m                 \u001B[0;32mraise\u001B[0m \u001B[0mKeyboardInterrupt\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"Interrupted by user\"\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0;32mNone\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m    905\u001B[0m             \u001B[0;32mexcept\u001B[0m \u001B[0mException\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m    906\u001B[0m                 \u001B[0mself\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mlog\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mwarning\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"Invalid Message:\"\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mexc_info\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
      "\u001B[0;31mKeyboardInterrupt\u001B[0m: Interrupted by user"
     ]
    }
   ],
   "source": [
    "# Interactive chat: type a sentence to print the model's reply; 'q' quits.\n",
    "sentence = input()\n",
    "while sentence != 'q':\n",
    "    out, weight = predict_seq2seq(net, sentence, lang, num_steps, device)\n",
    "    print(out)\n",
    "    sentence = input()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Wed Nov 30 17:26:19 2022       \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| NVIDIA-SMI 470.141.03   Driver Version: 470.141.03   CUDA Version: 11.4     |\r\n",
      "|-------------------------------+----------------------+----------------------+\r\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\r\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\r\n",
      "|                               |                      |               MIG M. |\r\n",
      "|===============================+======================+======================|\r\n",
      "|   0  NVIDIA GeForce ...  Off  | 00000000:02:00.0 Off |                  N/A |\r\n",
      "| 54%   48C    P8     1W / 250W |   3588MiB / 11019MiB |      0%      Default |\r\n",
      "|                               |                      |                  N/A |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "|   1  NVIDIA GeForce ...  Off  | 00000000:03:00.0 Off |                  N/A |\r\n",
      "| 27%   38C    P8    20W / 250W |      3MiB / 11019MiB |      0%      Default |\r\n",
      "|                               |                      |                  N/A |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "|   2  NVIDIA GeForce ...  Off  | 00000000:82:00.0 Off |                  N/A |\r\n",
      "| 27%   32C    P8     1W / 250W |      3MiB / 11019MiB |      0%      Default |\r\n",
      "|                               |                      |                  N/A |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "|   3  NVIDIA GeForce ...  Off  | 00000000:83:00.0 Off |                  N/A |\r\n",
      "| 27%   35C    P8    18W / 250W |      3MiB / 11019MiB |      0%      Default |\r\n",
      "|                               |                      |                  N/A |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "                                                                               \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| Processes:                                                                  |\r\n",
      "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\r\n",
      "|        ID   ID                                                   Usage      |\r\n",
      "|=============================================================================|\r\n",
      "|    0   N/A  N/A     13028      C   ...nda3/envs/py37/bin/python     3585MiB |\r\n",
      "+-----------------------------------------------------------------------------+\r\n"
     ]
    }
   ],
   "source": [
    "# Inspect GPU utilisation, memory usage and the processes holding GPU memory.\n",
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": "400"
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
